[
  {
    "path": ".dockerignore",
    "content": "data/\r\n.git\r\n.cache"
  },
  {
    "path": ".gitignore",
    "content": "data/\nlogs/\npackages/\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"docker/containers\"]\n\tpath = docker/containers\n\turl = https://github.com/dusty-nv/jetson-containers\n"
  },
  {
    "path": "Dockerfile.aarch64",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# Permission is hereby granted, free of charge, to any person obtaining a\r\n# copy of this software and associated documentation files (the \"Software\"),\r\n# to deal in the Software without restriction, including without limitation\r\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\r\n# and/or sell copies of the Software, and to permit persons to whom the\r\n# Software is furnished to do so, subject to the following conditions:\r\n#\r\n# The above copyright notice and this permission notice shall be included in\r\n# all copies or substantial portions of the Software.\r\n#\r\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\r\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\r\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\r\n# DEALINGS IN THE SOFTWARE.\r\n\r\nARG BASE_IMAGE\r\nFROM ${BASE_IMAGE}\r\n\r\nENV DEBIAN_FRONTEND=noninteractive\r\nENV SHELL /bin/bash\r\nENV LANG='en_US.UTF-8' LANGUAGE='en_US:en' LC_ALL='en_US.UTF-8'\r\nARG MAKEFLAGS=-j$(nproc)\r\nARG WORKSPACE=/jetson-voice\r\n\r\nWORKDIR ${WORKSPACE}\r\n\r\n# alias python3 -> python\r\nRUN rm /usr/bin/python && \\\r\n    ln -s /usr/bin/python3 /usr/bin/python && \\\r\n    ln -s /usr/bin/pip3 /usr/bin/pip\r\n\r\n\r\n################################################################\r\n## tokenizers/transformers\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n            cmake \\\r\n\t\t  curl \\\r\n\t\t  pkg-config \\\r\n\t\t  protobuf-compiler \\\r\n\t\t  libprotoc-dev \\\r\n\t\t  
nano \\\r\n\t\t  tzdata \\\r\n\t\t  libssl-dev \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\n# install sentencepiece\r\nRUN git clone https://github.com/google/sentencepiece && \\\r\n\tcd sentencepiece && \\\r\n\tmkdir build && \\\r\n\tcd build && \\\r\n\tcmake .. && \\\r\n\tmake -j $(nproc) && \\\r\n\tmake install && \\\r\n\tldconfig -v && \\\r\n\tcd .. && \\\r\n\tcd python && \\\r\n\tpython3 setup.py install --verbose && \\\r\n\tcd ../../ && \\\r\n\trm -r -f sentencepiece\r\n\r\n# install rust (used by tokenizers)\r\nRUN curl https://sh.rustup.rs -sSf | sh -s -- -y\r\nENV PATH=\"/root/.cargo/bin:${PATH}\"\r\nRUN rustc --version && \\\r\n    pip3 install setuptools-rust\r\n\r\n# install tokenizers\r\nRUN pip3 install tokenizers --verbose\r\n\r\n# Apache arrow is needed by datasets package ('pip install pyarrow' is broken, so built from source)\r\n#  https://github.com/apache/arrow/blob/master/docs/source/developers/python.rst#using-pip\r\n#  https://raspberrypi.stackexchange.com/a/117723\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t\tlibjemalloc-dev \\\r\n\t\t\tlibboost-dev \\\r\n\t\t\tlibboost-filesystem-dev \\\r\n\t\t\tlibboost-system-dev \\\r\n\t\t\tlibboost-regex-dev \\\r\n\t\t\tautoconf \\\r\n\t\t\tflex \\\r\n\t\t\tbison \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean \r\n\r\nRUN git clone --branch apache-arrow-3.0.0 https://github.com/apache/arrow.git && \\\r\n\tcd arrow/cpp && \\\r\n\tmkdir build && \\\r\n\tcd build && \\\r\n\texport ARROW_HOME=/usr/local && \\\r\n\tcmake \\\r\n\t\t-DCMAKE_INSTALL_PREFIX=$ARROW_HOME \\\r\n\t\t-DCMAKE_INSTALL_LIBDIR=lib \\\r\n\t\t-DARROW_WITH_BZ2=ON \\\r\n\t\t-DARROW_WITH_ZLIB=ON \\\r\n\t\t-DARROW_WITH_ZSTD=ON \\\r\n\t\t-DARROW_WITH_LZ4=ON \\\r\n\t\t-DARROW_WITH_SNAPPY=ON \\\r\n\t\t-DARROW_PARQUET=ON \\\r\n\t\t-DARROW_CUDA=ON \\\r\n\t\t-DARROW_PYTHON=ON \\\r\n\t\t-DARROW_BUILD_TESTS=OFF \\\r\n\t\t.. 
&& \\\r\n\tmake -j$(nproc) && \\\r\n\tmake install && \\\r\n\tcd ../../python && \\\r\n\tpython3 setup.py build_ext --build-type=release --with-parquet --with-cuda --verbose && \\\r\n\tpython3 setup.py install --verbose && \\\r\n\tcd ../../ && \\\r\n\trm -r -f arrow\r\n\r\nRUN pip3 show pyarrow && \\\r\n\tpython3 -c \"import pyarrow\" && \\\r\n\tpython3 -c \"from pyarrow import cuda\"\r\n\t\r\n# install huggingface (locked to 4.5.1, which the patches are based on)\r\n# datasets package is needed to run the huggingface examples\r\nRUN pip3 install transformers==4.5.1 datasets --verbose\r\n  \r\n\r\n################################################################\r\n## onnx / onnxruntime / onnx-graphsurgeon\r\n################################################################\r\nARG ONNXRUNTIME_URL=https://nvidia.box.com/shared/static/ukszbm1iklzymrt54mgxbzjfzunq7i9t.whl\r\nARG ONNXRUNTIME_WHL=onnxruntime_gpu-1.7.0-cp36-cp36m-linux_aarch64.whl\r\n\r\nRUN wget --quiet --show-progress --progress=bar:force:noscroll --no-check-certificate ${ONNXRUNTIME_URL} -O ${ONNXRUNTIME_WHL} && \\\r\n    pip3 install ${ONNXRUNTIME_WHL} --verbose && \\\r\n    pip3 install onnx psutil sympy --verbose && \\\r\n    rm ${ONNXRUNTIME_WHL}\r\n\r\n# install onnx-graphsurgeon\r\nRUN cd /opt && \\\r\n    git clone --recursive https://github.com/nvidia/tensorrt tensorrt && \\\r\n    cd tensorrt/tools/onnx-graphsurgeon && \\\r\n    python3 setup.py install --verbose && \\\r\n    cd ../../../ && \\\r\n    rm -r -f tensorrt\r\n    \r\n    \r\n################################################################\r\n## NeMo\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t  libopencc-dev \\\r\n\t\t  python3-tk \\\r\n\t\t  libmecab-dev \\\r\n\t\t  mecab \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\nRUN cd /opt && \\\r\n    git clone --recursive --branch v0.11.1 
https://github.com/pytorch/text torchtext && \\\r\n    cd torchtext && \\\r\n    python3 setup.py clean install \r\n    \r\nRUN pip3 show torch torchvision torchaudio torchtext\r\n\r\n# clone/build nemo\r\nARG NEMO_VERSION\r\nRUN cd /opt && git clone --recursive --branch v${NEMO_VERSION} https://github.com/nvidia/nemo\r\n\r\n# needed for nemo 1.0\r\n#COPY patches/nemo/${NEMO_VERSION}/setup.py /opt/nemo/setup.py\r\n\r\n# needed for nemo 1.6\r\nCOPY patches/nemo/${NEMO_VERSION}/requirements.txt /opt/nemo/requirements/requirements.txt\r\nCOPY patches/nemo/${NEMO_VERSION}/requirements_nlp.txt /opt/nemo/requirements/requirements_nlp.txt\r\n\r\nRUN pip3 install -r /opt/nemo/requirements/requirements.txt --verbose\r\nRUN pip3 install -r /opt/nemo/requirements/requirements_asr.txt --verbose\r\nRUN pip3 install -r /opt/nemo/requirements/requirements_nlp.txt --verbose\r\nRUN pip3 install -r /opt/nemo/requirements/requirements_tts.txt --verbose\r\n#RUN pip3 install omegaconf==2.1.0dev24 --verbose\r\n\r\nRUN cd /opt/nemo && python3 setup.py install --verbose\r\n\r\n\r\n################################################################\r\n## ctc-decoders\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t  swig \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\nRUN git clone https://github.com/dusty-nv/OpenSeq2Seq -b ctc-decoders && \\\r\n    cd OpenSeq2Seq/decoders && \\\r\n    ./setup.sh\r\n    \r\nRUN pip3 install git+https://github.com/NVIDIA/dllogger\r\nRUN pip3 install nltk\r\n\r\n\r\n################################################################\r\n## Riva GRPC\r\n################################################################\r\nARG RIVA_URL=https://nvidia.box.com/shared/static/cu8z4t1n6shkxl6z5nh9hpkpn9yxomcz.whl\r\nARG RIVA_WHL=riva_api-1.0.0ea-py3-none-any.whl\r\n\r\nRUN wget --quiet --show-progress --progress=bar:force:noscroll 
--no-check-certificate ${RIVA_URL} -O ${RIVA_WHL} && \\\r\n    pip3 install ${RIVA_WHL} --verbose && \\\r\n    rm ${RIVA_WHL}\r\n\r\n\r\n################################################################\r\n## install some audio stuff\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t  alsa-base \\\r\n            libasound2-dev \\\r\n            alsa-utils \\\r\n            portaudio19-dev \\\r\n\t\t  libsndfile1 \\\r\n\t\t  unzip \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\nRUN pip3 install soundfile pyaudio wave\r\n\r\n\r\n################################################################\r\n## various patches to install\r\n################################################################\r\n#COPY patches patches\r\n\r\n#RUN PYTHON_ROOT=`pip3 show torch | grep Location: | cut -d' ' -f2` && \\\r\n#    PYTORCH_VERSION=`pip3 show torch | grep Version: | cut -d' ' -f2` && \\\r\n#    TRANSFORMERS_VERSION=`pip3 show transformers | grep Version: | cut -d' ' -f2` && \\\r\n#    NEMO_PATH=\"$PYTHON_ROOT/nemo_toolkit-${NEMO_VERSION}-py3.6.egg/nemo\" && \\\r\n#    echo \"Python package root path:  $PYTHON_ROOT\" && \\\r\n#    echo \"Applying patches for PyTorch $PYTORCH_VERSION\" && \\\r\n#    echo \"Applying patches for transformers $TRANSFORMERS_VERSION\" && \\\r\n#    cp patches/pytorch/$PYTORCH_VERSION/functional.py $PYTHON_ROOT/torch/functional.py && \\\r\n#    cp patches/transformers/$TRANSFORMERS_VERSION/convert_graph_to_onnx.py $PYTHON_ROOT/transformers/convert_graph_to_onnx.py && \\\r\n#    cp patches/transformers/$TRANSFORMERS_VERSION/modeling_distilbert.py $PYTHON_ROOT/transformers/models/distilbert/modeling_distilbert.py && \\\r\n#    cp patches/nemo/${NEMO_VERSION}/nlp/distilbert.py $NEMO_PATH/collections/nlp/modules/common/huggingface/distilbert.py && \\\r\n#    cp patches/nemo/${NEMO_VERSION}/exportable.py 
$NEMO_PATH/core/classes/exportable.py\r\n\r\n\r\n# set Python to unicode\r\nENV PYTHONIOENCODING=utf-8\r\n\r\n# disable JupyterLab from auto-starting (inherited behavior from l4t-ml)\r\nCMD /bin/bash\r\n"
  },
  {
    "path": "Dockerfile.ros",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# Permission is hereby granted, free of charge, to any person obtaining a\r\n# copy of this software and associated documentation files (the \"Software\"),\r\n# to deal in the Software without restriction, including without limitation\r\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\r\n# and/or sell copies of the Software, and to permit persons to whom the\r\n# Software is furnished to do so, subject to the following conditions:\r\n#\r\n# The above copyright notice and this permission notice shall be included in\r\n# all copies or substantial portions of the Software.\r\n#\r\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\r\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\r\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\r\n# DEALINGS IN THE SOFTWARE.\r\n\r\nARG BASE_IMAGE=jetson-voice:r32.5.0-foxy-base\r\nFROM ${BASE_IMAGE}\r\n\r\n\r\n################################################################\r\n## install jetson_voice_ros package\r\n################################################################\r\nCOPY ros /tmp/jetson_voice_ros\r\n    \r\nRUN source ${ROS_ROOT}/install/setup.bash && \\\r\n    mkdir -p ${ROS_ROOT}/src && \\\r\n    cd ${ROS_ROOT} && \\\r\n    cp -r /tmp/jetson_voice_ros src && \\\r\n    \r\n    # build the package\r\n    colcon build \\\r\n        --merge-install \\\r\n\t    --base-paths src/jetson_voice_ros \\\r\n        --event-handlers console_direct+ && \\\r\n\t  \r\n    # clean-up build files\r\n    rm -rf ${ROS_ROOT}/src && \\\r\n    rm -rf ${ROS_ROOT}/logs && \\\r\n    rm -rf 
${ROS_ROOT}/build\r\n\r\n\r\n################################################################\r\n## project install\r\n################################################################\r\nARG WORKSPACE=/jetson-voice\r\n\r\nCOPY jetson_voice ${WORKSPACE}/jetson_voice\r\nCOPY examples ${WORKSPACE}/examples\r\nCOPY scripts ${WORKSPACE}/scripts\r\nCOPY tests ${WORKSPACE}/tests\r\n\r\nENV PYTHONPATH=\"${WORKSPACE}:${PYTHONPATH}\"\r\n"
  },
  {
    "path": "Dockerfile.runtime",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# Permission is hereby granted, free of charge, to any person obtaining a\r\n# copy of this software and associated documentation files (the \"Software\"),\r\n# to deal in the Software without restriction, including without limitation\r\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\r\n# and/or sell copies of the Software, and to permit persons to whom the\r\n# Software is furnished to do so, subject to the following conditions:\r\n#\r\n# The above copyright notice and this permission notice shall be included in\r\n# all copies or substantial portions of the Software.\r\n#\r\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\r\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\r\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\r\n# DEALINGS IN THE SOFTWARE.\r\n\r\nARG BASE_IMAGE=jetson-voice:r32.5.0-base\r\nFROM ${BASE_IMAGE}\r\n\r\nARG WORKSPACE=/jetson-voice\r\nWORKDIR ${WORKSPACE}\r\n\r\n\r\n################################################################\r\n## project install\r\n################################################################\r\nCOPY jetson_voice ${WORKSPACE}/jetson_voice\r\nCOPY examples ${WORKSPACE}/examples\r\nCOPY scripts ${WORKSPACE}/scripts\r\nCOPY tests ${WORKSPACE}/tests\r\n\r\nENV PYTHONPATH=\"${WORKSPACE}:${PYTHONPATH}\""
  },
  {
    "path": "Dockerfile.x86_64",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# Permission is hereby granted, free of charge, to any person obtaining a\r\n# copy of this software and associated documentation files (the \"Software\"),\r\n# to deal in the Software without restriction, including without limitation\r\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\r\n# and/or sell copies of the Software, and to permit persons to whom the\r\n# Software is furnished to do so, subject to the following conditions:\r\n#\r\n# The above copyright notice and this permission notice shall be included in\r\n# all copies or substantial portions of the Software.\r\n#\r\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL\r\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\r\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\r\n# DEALINGS IN THE SOFTWARE.\r\n\r\nARG BASE_IMAGE\r\nFROM ${BASE_IMAGE}\r\n\r\nENV DEBIAN_FRONTEND=noninteractive\r\nENV SHELL /bin/bash\r\nARG MAKEFLAGS=-j$(nproc)\r\nARG WORKSPACE=/jetson-voice\r\n\r\nWORKDIR ${WORKSPACE}\r\n\r\n\r\n################################################################\r\n## PyCUDA\r\n################################################################\r\nRUN pip3 install pycuda six --verbose\r\n\r\n\r\n################################################################\r\n## ctc-decoders\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t  swig \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\nRUN git clone https://github.com/dusty-nv/OpenSeq2Seq -b ctc-decoders && 
\\\r\n    cd OpenSeq2Seq/decoders && \\\r\n    ./setup.sh\r\n    \r\nRUN pip3 install git+https://github.com/NVIDIA/dllogger\r\nRUN pip3 install nltk\r\n\r\n\r\n################################################################\r\n## Jarvis GRPC\r\n################################################################\r\nARG JARVIS_URL=https://nvidia.box.com/shared/static/on9t7zqes2s6er1wpumidnc6rphwsyy7.whl\r\nARG JARVIS_WHL=jarvis_api-1.0.0b1-py3-none-any.whl\r\n\r\nRUN wget --quiet --show-progress --progress=bar:force:noscroll --no-check-certificate ${JARVIS_URL} -O ${JARVIS_WHL} && \\\r\n    pip3 install ${JARVIS_WHL} --verbose && \\\r\n    rm ${JARVIS_WHL}\r\n    \r\n    \r\n################################################################\r\n## install some audio stuff\r\n################################################################\r\nRUN apt-get update && \\\r\n    apt-get install -y --no-install-recommends \\\r\n\t\t  alsa-base \\\r\n            libasound2-dev \\\r\n            alsa-utils \\\r\n            portaudio19-dev \\\r\n\t\t  libsndfile1 \\\r\n\t\t  unzip \\\r\n\t\t  tzdata \\\r\n\t\t  nano \\\r\n    && rm -rf /var/lib/apt/lists/* \\\r\n    && apt-get clean\r\n    \r\nRUN pip3 install soundfile pyaudio wave\r\n\r\n\r\n################################################################\r\n## various patches to install\r\n################################################################\r\nCOPY patches patches\r\n\r\nARG NEMO_VERSION\r\nRUN PYTHON_ROOT=`pip3 show transformers | grep Location: | cut -d' ' -f2` && \\\r\n    TRANSFORMERS_VERSION=`pip3 show transformers | grep Version: | cut -d' ' -f2` && \\\r\n    echo \"Python package root path:  $PYTHON_ROOT\" && \\\r\n    echo \"Applying patches for transformers $TRANSFORMERS_VERSION\" && \\\r\n    cp patches/transformers/$TRANSFORMERS_VERSION/convert_graph_to_onnx.py $PYTHON_ROOT/transformers/convert_graph_to_onnx.py && \\\r\n    cp patches/transformers/$TRANSFORMERS_VERSION/modeling_distilbert.py 
$PYTHON_ROOT/transformers/models/distilbert/modeling_distilbert.py && \\\r\n    cp patches/nemo/${NEMO_VERSION}/nlp/distilbert.py $PYTHON_ROOT/nemo/collections/nlp/modules/common/huggingface/distilbert.py && \\\r\n    cp patches/nemo/${NEMO_VERSION}/exportable.py $PYTHON_ROOT/nemo/core/classes/exportable.py\r\n\r\n\r\n# set Python to unicode\r\nENV PYTHONIOENCODING=utf-8\r\n    "
  },
  {
    "path": "README.md",
    "content": "# jetson-voice\r\n\r\njetson-voice is an ASR/NLP/TTS deep learning inference library for Jetson Nano, TX1/TX2, Xavier NX, and AGX Xavier.  It supports Python and JetPack 4.4.1 or newer.  The DNN models were trained with [NeMo](https://github.com/NVIDIA/NeMo) and deployed with [TensorRT](https://developer.nvidia.com/tensorrt) for optimized performance.  All computation is performed using the onboard GPU.\r\n\r\nCurrently the following capabilities are included:\r\n\r\n* [Automatic Speech Recognition (ASR)](#automatic-speech-recognition-asr)\r\n\t* [Streaming ASR (QuartzNet)](#automatic-speech-recognition-asr) \r\n\t* [Command/Keyword Recognition (MatchboxNet)](#commandkeyword-recognition)\r\n\t* [Voice Activity Detection (VAD Marblenet)](#voice-activity-detection-vad)\r\n* [Natural Language Processing (NLP)](#natural-language-processing-nlp)\r\n\t* [Joint Intent/Slot Classification](#joint-intentslot-classification)\r\n\t* [Text Classification (Sentiment Analysis)](#text-classification)\r\n\t* [Token Classification (Named Entity Recognition)](#token-classification)\r\n\t* [Question/Answering (QA)](#questionanswering)\r\n* [Text-to-Speech (TTS)](#text-to-speech-tts)\r\n\t\r\nThe NLP models are using the [DistilBERT](https://arxiv.org/abs/1910.01108) transformer architecture for reduced memory usage and increased performance.  For samples of the text-to-speech output, see the [TTS Audio Samples](#tts-audio-samples) section below.\r\n\r\n## Running the Container\r\n\r\njetson-voice is distributed as a Docker container due to the number of dependencies.  
There are pre-built container images available on DockerHub for JetPack 4.4.1 and newer:\r\n\r\n```\r\ndustynv/jetson-voice:r32.4.4    # JetPack 4.4.1 (L4T R32.4.4)\r\ndustynv/jetson-voice:r32.5.0    # JetPack 4.5 (L4T R32.5.0) / JetPack 4.5.1 (L4T R32.5.1)\r\ndustynv/jetson-voice:r32.6.1    # JetPack 4.6 (L4T R32.6.1)\r\ndustynv/jetson-voice:r32.7.1    # JetPack 4.6.1 (L4T R32.7.1)\r\n```\r\n\r\nTo download and run the container, you can simply clone this repo and use the `docker/run.sh` script:\r\n\r\n``` bash\r\n$ git clone --branch dev https://github.com/dusty-nv/jetson-voice\r\n$ cd jetson-voice\r\n$ docker/run.sh\r\n```\r\n\r\n> **note**:  if you want to use a USB microphone or speaker, plug it in *before* you start the container\r\n\r\nThere are some optional arguments to `docker/run.sh` that you can use:\r\n\r\n* `-r` (`--run`) specifies a run command; otherwise the container will start in an interactive shell.\r\n* `-v` (`--volume`) mounts a directory from the host into the container (`/host/path:/container/path`)\r\n* `--dev` starts the container in development mode, where all the source files are mounted for easy editing\r\n\r\nThe run script will automatically mount the `data/` directory into the container, which stores the models and other data files.  If you save files from the container there, they will also show up under `data/` on the host.\r\n\r\n## Automatic Speech Recognition (ASR)\r\n\r\nThe speech recognition in jetson-voice is a streaming service, so it's intended to be used on live sources and transcribes the audio in 1-second chunks.  It uses a [QuartzNet-15x5](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#quartznet) model followed by a CTC beam search decoder and language model to further refine the raw output of the network.  It detects breaks in the audio to determine the end of sentences.  
For information about using the ASR APIs, please refer to [`jetson_voice/asr.py`](jetson_voice/asr.py) and see [`examples/asr.py`](examples/asr.py)\r\n\r\nAfter you start the container, first run a test audio file (wav/ogg/flac) through [`examples/asr.py`](examples/asr.py) to verify that the system is functional.  Run this command (and all subsequent commands) inside the container:\r\n\r\n``` bash\r\n$ examples/asr.py --wav data/audio/dusty.wav\r\n\r\nhi\r\nhi hi this is dust\r\nhi hi this is dusty check\r\nhi hi this is dusty check one two\r\nhi hi this is dusty check one two three\r\nhi hi this is dusty check one two three.\r\n\r\nwhat's the weather or\r\nwhat's the weather going to be tomorrow\r\nwhat's the weather going to be tomorrow in pittsburgh\r\nwhat's the weather going to be tomorrow in pittsburgh.\r\n\r\ntoday is\r\ntoday is wednesday\r\ntoday is wednesday tomorrow is thursday\r\ntoday is wednesday tomorrow is thursday.\r\n\r\ni would like\r\ni would like to order a large\r\ni would like to order a large pepperoni pizza\r\ni would like to order a large pepperoni pizza.\r\n\r\nis it going to be\r\nis it going to be cloudy tomorrow.\r\n```\r\n\r\n> The first time you run each model, TensorRT will take a few minutes to optimize it.  
\r\n> This optimized model is then cached to disk, so the next time you run the model it will load faster.\r\n\r\n#### Live Microphone\r\n\r\nTo test the ASR on a mic, first list the audio devices in your system to get the audio device ID's:\r\n\r\n``` bash\r\n$ scripts/list_audio_devices.sh\r\n\r\n----------------------------------------------------\r\n Audio Input Devices\r\n----------------------------------------------------\r\nInput Device ID 1 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,0)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 2 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,1)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 3 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,2)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 4 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,3)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 5 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,4)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 6 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,5)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 7 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,6)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 8 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,7)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 9 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,8)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 10 - 'tegra-snd-t210ref-mobile-rt565x: - (hw:1,9)' (inputs=16) (sample_rate=44100)\r\nInput Device ID 11 - 'Logitech H570e Mono: USB Audio (hw:2,0)' (inputs=2) (sample_rate=44100)\r\nInput Device ID 12 - 'Samson Meteor Mic: USB Audio (hw:3,0)' (inputs=2) (sample_rate=44100)\r\n```\r\n\r\n> If you don't see your audio device listed, exit and restart the container.  
\r\n> USB devices should be attached *before* the container is started.\r\n\r\nThen run the ASR example with the `--mic <DEVICE>` option, and specify either the device ID or name:\r\n\r\n``` bash\r\n$ examples/asr.py --mic 11\r\n\r\nhey\r\nhey how are you guys\r\nhey how are you guys.\r\n\r\n# (Press Ctrl+C to exit)\r\n```\r\n\r\n## ASR Classification\r\n\r\nThere are other ASR models included for command/keyword recognition ([MatchboxNet](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#matchboxnet-speech-commands)) and voice activity detection ([VAD MarbleNet](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#marblenet-vad)).  These models are smaller and faster, and classify chunks of audio as opposed to transcribing text.  \r\n\r\n### Command/Keyword Recognition\r\n\r\nThe [MatchboxNet](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#matchboxnet-speech-commands) model was trained on 12 keywords from the [Google Speech Commands](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html) dataset:\r\n\r\n```\r\n# MatchboxNet classes\r\n\"yes\",\r\n\"no\",\r\n\"up\",\r\n\"down\",\r\n\"left\",\r\n\"right\",\r\n\"on\",\r\n\"off\",\r\n\"stop\",\r\n\"go\",\r\n\"unknown\",\r\n\"silence\"\r\n```\r\n\r\nYou can run it through the same ASR example as above by specifying the `--model matchboxnet` argument:\r\n\r\n``` bash\r\n$ examples/asr.py --model matchboxnet --wav data/audio/commands.wav\r\n\r\nclass 'unknown' (0.384)\r\nclass 'yes' (1.000)\r\nclass 'no' (1.000)\r\nclass 'up' (1.000)\r\nclass 'down' (1.000)\r\nclass 'left' (1.000)\r\nclass 'left' (1.000)\r\nclass 'right' (1.000)\r\nclass 'on' (1.000)\r\nclass 'off' (1.000)\r\nclass 'stop' (1.000)\r\nclass 'go' (1.000)\r\nclass 'go' (1.000)\r\nclass 'silence' (0.639)\r\nclass 'silence' (0.576)\r\n```\r\n\r\nThe numbers printed on the 
right are the classification probabilities between 0 and 1.\r\n\r\n### Voice Activity Detection (VAD)\r\n\r\nThe voice activity model ([VAD MarbleNet](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#marblenet-vad)) is a binary model that outputs `background` or `speech`:\r\n\r\n``` bash\r\n$ examples/asr.py --model vad_marblenet --wav data/audio/commands.wav\r\n\r\nclass 'background' (0.969)\r\nclass 'background' (0.984)\r\nclass 'background' (0.987)\r\nclass 'speech' (0.997)\r\nclass 'speech' (1.000)\r\nclass 'speech' (1.000)\r\nclass 'speech' (0.998)\r\nclass 'background' (0.987)\r\nclass 'speech' (1.000)\r\nclass 'speech' (1.000)\r\nclass 'speech' (1.000)\r\nclass 'background' (0.988)\r\nclass 'background' (0.784)\r\n```\r\n\r\n## Natural Language Processing (NLP)\r\n\r\nThere are two samples included for NLP:\r\n\r\n* [`examples/nlp.py`](examples/nlp.py) (intent/slot, text classification, token classification)\r\n* [`examples/nlp_qa.py`](examples/nlp_qa.py) (question/answering)\r\n\r\nThese each use a [DistilBERT](https://arxiv.org/abs/1910.01108) model which has been fine-tuned for its particular task.  For information about using the NLP APIs, please refer to [`jetson_voice/nlp.py`](jetson_voice/nlp.py) and see the samples above.\r\n\r\n### Joint Intent/Slot Classification\r\n\r\nJoint Intent and Slot classification is the task of classifying an Intent and detecting all relevant Slots (Entities) for this Intent in a query. For example, in the query: `What is the weather in Santa Clara tomorrow morning?`, we would like to classify the query as a `weather` Intent, and detect `Santa Clara` as a location slot and `tomorrow morning` as a date_time slot. \r\n\r\nIntent and slot names are usually task-specific and defined as labels in the training data.  
The included intent/slot model was trained on the [NLU-Evaluation-Data](https://github.com/xliuhw/NLU-Evaluation-Data) dataset; you can find the various intent and slot classes that it supports [here](https://gist.github.com/dusty-nv/119474dfcf3bfccfbb8428951a64cd23).  They are common things that you might ask a virtual assistant:\r\n\r\n```\r\n$ examples/nlp.py --model distilbert_intent\r\n\r\nEnter intent_slot query, or Q to quit:\r\n\r\n> What is the weather in Santa Clara tomorrow morning?\r\n\r\n{'intent': 'weather_query',\r\n 'score': 0.7165476,\r\n 'slots': [{'score': 0.6280392, 'slot': 'place_name', 'text': 'Santa'},\r\n           {'score': 0.61760694, 'slot': 'place_name', 'text': 'Clara'},\r\n           {'score': 0.5439486, 'slot': 'date', 'text': 'tomorrow'},\r\n           {'score': 0.4520608, 'slot': 'date', 'text': 'morning'}]}\r\n\r\n> Set an alarm for 730am\r\n\r\n{'intent': 'alarm_set',\r\n 'score': 0.5713072,\r\n 'slots': [{'score': 0.40017933, 'slot': 'time', 'text': '730am'}]}\r\n\r\n> Turn up the volume\r\n\r\n{'intent': 'audio_volume_up', 'score': 0.33523008, 'slots': []}\r\n\r\n> What is my schedule for tomorrow?\r\n\r\n{'intent': 'calendar_query',\r\n 'score': 0.37434494,\r\n 'slots': [{'score': 0.5732627, 'slot': 'date', 'text': 'tomorrow'}]}\r\n\r\n> Order a pepperoni pizza from domino's\r\n\r\n{'intent': 'takeaway_order',\r\n 'score': 0.50629586,\r\n 'slots': [{'score': 0.27558547, 'slot': 'food_type', 'text': 'pepperoni'},\r\n           {'score': 0.2778827, 'slot': 'food_type', 'text': 'pizza'},\r\n           {'score': 0.21785143, 'slot': 'business_name', 'text': 'dominos'}]}\r\n\t\r\n> Where's the closest Starbucks?\r\n\r\n{'intent': 'recommendation_locations',\r\n 'score': 0.5438984,\r\n 'slots': [{'score': 0.1604197, 'slot': 'place_name', 'text': 'Starbucks'}]}\r\n\r\n```\r\n\r\n### Text Classification\r\n\r\nIn this text classification example, we'll use the included sentiment analysis model that was trained on the [Stanford 
Sentiment Treebank (SST-2)](https://nlp.stanford.edu/sentiment/index.html) dataset.  It will label queries as either positive or negative, along with their probability:\r\n\r\n```\r\n$ examples/nlp.py --model distilbert_sentiment\r\n\r\nEnter text_classification query, or Q to quit:\r\n\r\n> today was warm, sunny and beautiful out\r\n\r\n{'class': 1, 'label': '1', 'score': 0.9985898}\r\n\r\n> today was cold and rainy and not very nice\r\n\r\n{'class': 0, 'label': '0', 'score': 0.99136007}\r\n```\r\n\r\n(class 0 is negative sentiment and class 1 is positive sentiment)\r\n\r\n### Token Classification\r\n\r\nWhereas text classification classifies entire queries, token classification classifies individual tokens (or words).  In this example, we'll be performing Named Entity Recognition (NER), which is the task of detecting and classifying key information (entities) in text. For example, in a sentence: `Mary lives in Santa Clara and works at NVIDIA`, we should detect that `Mary` is a person, `Santa Clara` is a location and `NVIDIA` is a company.\r\n\r\nThe included token classification model for NER was trained on the [Groningen Meaning Bank (GMB)](http://www.let.rug.nl/bjerva/gmb/about.php) and supports the following annotations in [IOB format](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging)) (short for inside, outside, beginning)\r\n\r\n* LOC = Geographical Entity\r\n* ORG = Organization\r\n* PER = Person\r\n* GPE = Geopolitical Entity\r\n* TIME = Time indicator\r\n* MISC = Artifact, Event, or Natural Phenomenon\r\n\r\n``` bash\r\n$ examples/nlp.py --model distilbert_ner\r\n\r\nEnter token_classification query, or Q to quit:\r\n> Mary lives in Santa Clara and works at NVIDIA\r\n\r\nMary[B-PER 0.989] lives in Santa[B-LOC 0.998] Clara[I-LOC 0.996] and works at NVIDIA[B-ORG 0.967]\r\n\r\n> Lisa's favorite place to climb in the summer is El Capitan in Yosemite National Park in California, U.S.\r\n\r\nLisa's[B-PER 0.995] favorite place to 
climb in the summer[B-TIME 0.996] is El[B-PER 0.577] Capitan[I-PER 0.483] \r\nin Yosemite[B-LOC 0.987] National[I-LOC 0.988] Park[I-LOC 0.98] in California[B-LOC 0.998], U.S[B-LOC 0.997].\r\n```\r\n\r\n### Question/Answering\r\n\r\nQuestion/Answering (QA) works by supplying a context paragraph from which the model then extracts the best answer.  The [`nlp_qa.py`](examples/nlp_qa.py) example allows you to select from several built-in context paragraphs (or supply your own) and to ask questions about these topics.  \r\n\r\nThe QA model is flexible and doesn't need to be re-trained on different topics, as it was trained on the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) question/answering dataset which allows it to extract answers from a variety of contexts.  It essentially learns to identify the information most relevant to your query from the context passage, as opposed to learning the content itself.\r\n\r\n``` bash\r\n$ examples/nlp_qa.py \r\n\r\nContext:\r\nThe Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America. \r\nThis basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres \r\n(2,100,000 sq mi) are covered by the rainforest. The majority of the forest is contained within Brazil, \r\nwith 60% of the rainforest, followed by Peru with 13%, and Colombia with 10%.\r\n\r\nEnter a question, C to change context, P to print context, or Q to quit:\r\n\r\n> How big is the Amazon?\r\n\r\nAnswer: 7,000,000 square kilometres\r\nScore:  0.24993503093719482\r\n\r\n> which country has the most?\r\n\r\nAnswer: Brazil\r\nScore:  0.5964332222938538\r\n```\r\n\r\nTo change the topic or create one of your own, enter `C`:\r\n\r\n```\r\nEnter a question, C to change context, P to print context, or Q to quit:\r\n> C\r\n\r\nSelect from one of the following topics, or enter your own context paragraph:\r\n   1. Amazon\r\n   2. Geology\r\n   3. Moon Landing\r\n   4. Pi\r\n   5. 
Super Bowl 55\r\n> 3\r\n\r\nContext:\r\nThe first manned Moon landing was Apollo 11 on July, 20 1969. The first human to step on the Moon was \r\nastronaut Neil Armstrong followed second by Buzz Aldrin. They landed in the Sea of Tranquility with their \r\nlunar module the Eagle. They were on the lunar surface for 2.25 hours and collected 50 pounds of moon rocks.\r\n\r\nEnter a question, C to change context, P to print context, or Q to quit:\r\n\r\n> Who was the first man on the moon?\r\n\r\nAnswer: Neil Armstrong\r\nScore:  0.39105066657066345\r\n```\r\n\r\n## Text-to-Speech (TTS)\r\n\r\nThe text-to-speech service uses an ensemble of two models:  FastPitch to generate MEL spectrograms from text, and HiFiGAN as the vocoder (female English voice).  For information about using the TTS APIs, please refer to [`jetson_voice/tts.py`](jetson_voice/tts.py) and see [`examples/tts.py`](examples/tts.py)\r\n\r\nThe [`examples/tts.py`](examples/tts.py) app can output the audio to a speaker, wav file, or sequence of wav files.  Run it with `--list-devices` to get a list of your audio devices.\r\n\r\n``` bash\r\n$ examples/tts.py --output-device 11 --output-wav data/audio/tts_test\r\n\r\n> The weather tomorrow is forecast to be warm and sunny with a high of 83 degrees.\r\n\r\nRun 0 -- Time to first audio: 1.820s. Generated 5.36s of audio. RTFx=2.95.\r\nRun 1 -- Time to first audio: 0.232s. Generated 5.36s of audio. RTFx=23.15.\r\nRun 2 -- Time to first audio: 0.230s. Generated 5.36s of audio. RTFx=23.31.\r\nRun 3 -- Time to first audio: 0.231s. Generated 5.36s of audio. RTFx=23.25.\r\nRun 4 -- Time to first audio: 0.230s. Generated 5.36s of audio. RTFx=23.36.\r\nRun 5 -- Time to first audio: 0.230s. Generated 5.36s of audio. RTFx=23.35.\r\n\r\nWrote audio to data/audio/tts_test/0.wav\r\n\r\nEnter text, or Q to quit:\r\n> Sally sells seashells by the seashore.\r\n\r\nRun 0 -- Time to first audio: 0.316s. Generated 2.73s of audio. RTFx=8.63.\r\nRun 1 -- Time to first audio: 0.126s. 
Generated 2.73s of audio. RTFx=21.61.\r\nRun 2 -- Time to first audio: 0.127s. Generated 2.73s of audio. RTFx=21.51.\r\nRun 3 -- Time to first audio: 0.126s. Generated 2.73s of audio. RTFx=21.68.\r\nRun 4 -- Time to first audio: 0.126s. Generated 2.73s of audio. RTFx=21.68.\r\nRun 5 -- Time to first audio: 0.126s. Generated 2.73s of audio. RTFx=21.61.\r\n\r\nWrote audio to data/audio/tts_test/1.wav\r\n```\r\n\r\n#### TTS Audio Samples\r\n\r\n* [Weather forecast](data/audio/tts_examples/0.wav) (wav)\r\n* [Sally sells seashells](data/audio/tts_examples/1.wav) (wav)\r\n\r\n\r\n## Tests\r\n\r\nThere is an automated test suite included that will verify all of the models are working properly.  You can run it with the `tests/run_tests.py` script:\r\n\r\n``` bash\r\n$ tests/run_tests.py\r\n\r\n----------------------------------------------------\r\n TEST SUMMARY\r\n----------------------------------------------------\r\ntest_asr.py (quartznet)                  PASSED\r\ntest_asr.py (quartznet_greedy)           PASSED\r\ntest_asr.py (matchboxnet)                PASSED\r\ntest_asr.py (vad_marblenet)              PASSED\r\ntest_nlp.py (distilbert_qa_128)          PASSED\r\ntest_nlp.py (distilbert_qa_384)          PASSED\r\ntest_nlp.py (distilbert_intent)          PASSED\r\ntest_nlp.py (distilbert_sentiment)       PASSED\r\ntest_nlp.py (distilbert_ner)             PASSED\r\ntest_tts.py (fastpitch_hifigan)          PASSED\r\n\r\npassed 10 of 10 tests\r\nsaved logs to data/tests/logs/20210610_1512\r\n```\r\n\r\nThe logs of the individual tests are printed to the screen and saved to a timestamped directory.\r\n\r\n\r\n\r\n"
  },
  {
    "path": "docker/build.sh",
    "content": "#!/usr/bin/env bash\n\nROS_DISTRO=${1:-\"none\"}\nBASE_IMAGE=$2\nNEMO_VERSION=\"1.0.0rc1\"\n\n# find container tag from os version\nsource docker/tag.sh\n\nif [ $ARCH = \"aarch64\" ]; then\n\tif [ -z $BASE_IMAGE ]; then\n\t\tif [ $L4T_VERSION = \"32.7.1\" ]; then\n\t\t\tBASE_IMAGE=\"l4t-ml:r32.7.1-py3\"\n\t\t\t#BASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.7.1-py3\"\n\t\t\tNEMO_VERSION=\"1.6.2\"\n\t\telif [ $L4T_VERSION = \"32.6.1\" ]; then\n\t\t\tBASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.6.1-py3\"\n\t\telif [ $L4T_VERSION = \"32.5.0\" ] || [ $L4T_VERSION = \"32.5.1\" ]; then\n\t\t\tBASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.5.0-py3\"\n\t\telif [ $L4T_VERSION = \"32.4.4\" ]; then\n\t\t\tBASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.4.4-py3\"\n\t\telif [ $L4T_VERSION = \"32.4.3\" ]; then\n\t\t\tBASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.4.3-py3\"\n\t\telif [ $L4T_VERSION = \"32.4.2\" ]; then\n\t\t\tBASE_IMAGE=\"nvcr.io/nvidia/l4t-ml:r32.4.2-py3\"\n\t\telse\n\t\t\techo \"cannot build jetson-voice docker container for L4T R$L4T_VERSION\"\n\t\t\techo \"please upgrade to the latest JetPack, or build jetson-voice natively\"\n\t\t\texit 1\n\t\tfi\n\tfi\nelif [ $ARCH = \"x86_64\" ]; then\n\tBASE_IMAGE=${BASE_IMAGE:-\"nvcr.io/nvidia/nemo:$NEMO_VERSION\"}\nfi\n\nVOICE_CONTAINER=\"$CONTAINER_NAME:$TAG\"\nVOICE_CONTAINER_BASE=\"$VOICE_CONTAINER-base\"\n\n# build the base container\necho \"CONTAINER=$VOICE_CONTAINER_BASE\"\necho \"BASE_IMAGE=$BASE_IMAGE\"\n\nsudo docker build -t $VOICE_CONTAINER_BASE -f Dockerfile.$ARCH \\\n          --build-arg BASE_IMAGE=$BASE_IMAGE \\\n\t\t--build-arg NEMO_VERSION=$NEMO_VERSION \\\n\t\t.\n\t\t\n# build the runtime container\necho \"CONTAINER=$VOICE_CONTAINER\"\necho \"BASE_IMAGE=$VOICE_CONTAINER_BASE\"\n\nsudo docker build -t $VOICE_CONTAINER -f Dockerfile.runtime \\\n          --build-arg BASE_IMAGE=$VOICE_CONTAINER_BASE \\\n\t\t.\n\n# build ROS version of container\nif [[ \"$ROS_DISTRO\" != \"none\" ]] && [[ $ARCH = \"aarch64\" ]]; 
then\n\tROS_CONTAINER=\"$VOICE_CONTAINER-ros-$ROS_DISTRO\"\n\tROS_CONTAINER_BASE=\"$ROS_CONTAINER-base\"\n\t\n\t# copy files needed to build ROS container\n\tif [ ! -d \"packages/\" ]; then\n\t\tcp -r docker/containers/packages packages\n\tfi\n\t\n\t# opencv.csv mounts files that preclude us installing different version of opencv\n\t# temporarily disable the opencv.csv mounts while we build the container\n\tCV_CSV=\"/etc/nvidia-container-runtime/host-files-for-container.d/opencv.csv\"\n\n\tif [ -f \"$CV_CSV\" ]; then\n\t\tsudo mv $CV_CSV $CV_CSV.backup\n\tfi\n\t\n\t# build ROS on top of jetson-voice \n\techo \"CONTAINER=$ROS_CONTAINER_BASE\"\n\techo \"BASE_IMAGE=$VOICE_CONTAINER_BASE\"\n\n\tsudo docker build -t $ROS_CONTAINER_BASE -f docker/containers/Dockerfile.ros.$ROS_DISTRO \\\n          --build-arg BASE_IMAGE=$VOICE_CONTAINER_BASE \\\n\t\t.\n\t\n\t# install jetson_voice_ros package\n\techo \"CONTAINER=$ROS_CONTAINER\"\n\techo \"BASE_IMAGE=$ROS_CONTAINER_BASE\"\n\n\tsudo docker build -t $ROS_CONTAINER -f Dockerfile.ros \\\n          --build-arg BASE_IMAGE=$ROS_CONTAINER_BASE \\\n\t\t.\n\t\t\n\t# restore opencv.csv mounts\n\tif [ -f \"$CV_CSV.backup\" ]; then\n\t\tsudo mv $CV_CSV.backup $CV_CSV\n\tfi\nfi"
  },
  {
    "path": "docker/push.sh",
    "content": "#!/usr/bin/env bash\n\nROS_DISTRO=${1:-\"foxy\"}\nsource docker/tag.sh\n\n# push image\npush() \n{\n\tlocal remote_image=\"dustynv/$1\"\n\t\n\tsudo docker rmi $remote_image\n\tsudo docker tag $1 $remote_image\n\t\n\techo \"pushing image $remote_image\"\n\tsudo docker push $remote_image\n\techo \"done pushing image $remote_image\"\n}\n\npush \"$CONTAINER_NAME:$TAG\"\n\nROS_CONTAINER=\"$CONTAINER_NAME:$TAG-ros-$ROS_DISTRO\"\npush \"$ROS_CONTAINER\""
  },
  {
    "path": "docker/run.sh",
    "content": "#!/usr/bin/env bash\n#\n# Start an instance of the jetson-voice docker container.\n# See below or run this script with -h or --help to see usage options.\n#\n# This script should be run from the root dir of the jetson-voice project:\n#\n#     $ cd /path/to/your/jetson-voice\n#     $ docker/run.sh\n#\n\nshow_help() {\n    echo \" \"\n    echo \"usage: Starts the Docker container and runs a user-specified command\"\n    echo \" \"\n    echo \"   ./docker/run.sh --container DOCKER_IMAGE\"\n    echo \"                   --volume HOST_DIR:MOUNT_DIR\"\n    echo \"                   --run RUN_COMMAND\"\n    echo \" \"\n    echo \"args:\"\n    echo \" \"\n    echo \"   --help                       Show this help text and quit\"\n    echo \" \"\n    echo \"   -c, --container DOCKER_IMAGE Specifies the name of the Docker container\"\n    echo \"                                image to use (default: 'jetson-voice')\"\n    echo \" \"\n    echo \"   --ros ROS_DISTRO Starts the version of the container using the\"\n    echo \"                    specified ROS distro (or foxy if not specified)\"\n    echo \"                    This is overridden by the --container argument\"\n    echo \" \"\n    echo \"   -d, --dev  Runs the container in development mode, where the source\"\n    echo \"              files are mounted into the container dynamically, so they\"\n    echo \"              can more easily be edited from the host machine.\"\n    echo \" \"\n    echo \"   -v, --volume HOST_DIR:MOUNT_DIR Mount a path from the host system into\"\n    echo \"                                   the container.  
Should be specified as:\"\n    echo \" \"\n    echo \"                                      -v /my/host/path:/my/container/path\"\n    echo \" \"\n    echo \"                                   (these should be absolute paths)\"\n    echo \" \"\n    echo \"   -r, --run RUN_COMMAND  Command to run once the container is started.\"\n    echo \"                          Note that this argument must be invoked last,\"\n    echo \"                          as all further arguments will form the command.\"\n    echo \"                          If no run command is specified, an interactive\"\n    echo \"                          terminal into the container will be provided.\"\n    echo \" \"\n}\n\ndie() {\n    printf '%s\\n' \"$1\"\n    show_help\n    exit 1\n}\n\n# find container tag from os version\nsource docker/tag.sh\n\n# where the project resides inside docker\nDOCKER_ROOT=\"/jetson-voice\"\t\n\n# generate mount commands\nDATA_VOLUME=\"--volume $PWD/data:$DOCKER_ROOT/data\"\nDEV_VOLUME=\"\"\n\n# parse user arguments\nUSER_VOLUME=\"\"\nUSER_COMMAND=\"\"\n\nwhile :; do\n    case $1 in\n        -h|-\\?|--help)\n            show_help    # Display a usage synopsis.\n            exit\n            ;;\n        -c|--container)       # Takes an option argument; ensure it has been specified.\n            if [ \"$2\" ]; then\n                CONTAINER_IMAGE=$2\n                shift\n            else\n                die 'ERROR: \"--container\" requires a non-empty option argument.'\n            fi\n            ;;\n        --container=?*)\n            CONTAINER_IMAGE=${1#*=} # Delete everything up to \"=\" and assign the remainder.\n            ;;\n        --container=)         # Handle the case of an empty --image=\n            die 'ERROR: \"--container\" requires a non-empty option argument.'\n            ;;\n\t   --ros)\n            if [ \"$2\" ]; then\n                ROS_DISTRO=$2\n                shift\n            else\n                ROS_DISTRO=\"foxy\"\n            
fi\n            ;;\n        --ros=?*)\n            ROS_DISTRO=${1#*=} # Delete everything up to \"=\" and assign the remainder.\n            ;;\n        --ros=)         # Handle the case of an empty --image=\n            ROS_DISTRO=\"foxy\"\n            ;;\n\t   -d|--dev)\n            DEV_VOLUME=\"--volume $PWD/jetson_voice:$DOCKER_ROOT/jetson_voice --volume $PWD/examples:$DOCKER_ROOT/examples --volume $PWD/scripts:$DOCKER_ROOT/scripts --volume $PWD/tests:$DOCKER_ROOT/tests\"\n            ;;\n        -v|--volume)\n            if [ \"$2\" ]; then\n                USER_VOLUME=\" -v $2 \"\n                shift\n            else\n                die 'ERROR: \"--volume\" requires a non-empty option argument.'\n            fi\n            ;;\n        --volume=?*)\n            USER_VOLUME=\" -v ${1#*=} \" # Delete everything up to \"=\" and assign the remainder.\n            ;;\n        --volume=)         # Handle the case of an empty --image=\n            die 'ERROR: \"--volume\" requires a non-empty option argument.'\n            ;;\n        -r|--run)\n            if [ \"$2\" ]; then\n                shift\n                USER_COMMAND=\" $@ \"\n            else\n                die 'ERROR: \"--run\" requires a non-empty option argument.'\n            fi\n            ;;\n        --)              # End of all options.\n            shift\n            break\n            ;;\n        -?*)\n            printf 'WARN: Unknown option (ignored): %s\\n' \"$1\" >&2\n            ;;\n        *)               # Default case: No more options, so break out of the loop.\n            break\n    esac\n\n    shift\ndone\n\n# select the container, unless --container was explicitly specified\nif [ -z \"$CONTAINER_IMAGE\" ]; then\n\tCONTAINER_IMAGE=\"$CONTAINER_NAME:$TAG\"\n\n\tif [ -n \"$ROS_DISTRO\" ]; then\n\t\tCONTAINER_IMAGE=\"$CONTAINER_NAME:$TAG-ros-$ROS_DISTRO\"\n\tfi\n\n\tCONTAINER_REMOTE_IMAGE=\"dustynv/$CONTAINER_IMAGE\"\n\n\t# check for local image\n\tif [[ \"$(sudo docker images 
-q $CONTAINER_IMAGE 2> /dev/null)\" == \"\" ]]; then\n\t\tCONTAINER_IMAGE=$CONTAINER_REMOTE_IMAGE\n\tfi\nfi\n\necho \"CONTAINER:     $CONTAINER_IMAGE\"\necho \"DEV_VOLUME:    $DEV_VOLUME\"\necho \"DATA_VOLUME:   $DATA_VOLUME\"\necho \"USER_VOLUME:   $USER_VOLUME\"\necho \"USER_COMMAND:  $USER_COMMAND\"\n\nMOUNTS=\"\\\n--device /dev/snd \\\n--device /dev/bus/usb \\\n--volume /etc/timezone:/etc/timezone:ro \\\n--volume /etc/localtime:/etc/localtime:ro \\\n$DEV_VOLUME \\\n$DATA_VOLUME \\\n$USER_VOLUME\"\n\nif [ $ARCH = \"aarch64\" ]; then\n\n\tsudo docker run --runtime nvidia -it --rm \\\n\t\t--name=$CONTAINER_NAME \\\n\t\t--network host \\\n\t\t$MOUNTS $CONTAINER_IMAGE $USER_COMMAND\n\t    \nelif [ $ARCH = \"x86_64\" ]; then\n\n\tsudo docker run --gpus all -it --rm \\\n\t\t--name=$CONTAINER_NAME \\\n\t\t--network=host \\\n\t\t--shm-size=8g \\\n\t\t--ulimit memlock=-1 \\\n\t\t--ulimit stack=67108864 \\\n\t\t$MOUNTS $CONTAINER_IMAGE $USER_COMMAND\n\t\t\nfi\n"
  },
  {
    "path": "docker/tag.sh",
    "content": "#!/usr/bin/env bash\n\n# find OS version\nsource scripts/os_version.sh\n\nif [ $ARCH = \"aarch64\" ]; then\n\tTAG=\"r$L4T_VERSION\"\n\t\n\tif [ $L4T_VERSION = \"32.5.1\" ] || [ $L4T_VERSION = \"32.5.2\" ]; then\n\t\tTAG=\"r32.5.0\"\n\tfi\t\nelif [ $ARCH = \"x86_64\" ]; then\n\tTAG=\"$ARCH\"\nelse\n\techo \"unsupported architecture:  $ARCH\"\n\texit 1\nfi\n\nCONTAINER_NAME=\"jetson-voice\"\n\n\n"
  },
  {
    "path": "examples/asr.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport sys\n\nfrom jetson_voice import ASR, AudioInput, ConfigArgParser, list_audio_devices\n    \n    \nparser = ConfigArgParser()\n\nparser.add_argument('--model', default='quartznet', type=str, help='path to model, service name, or json config file')\nparser.add_argument('--wav', default=None, type=str, help='path to input wav/ogg/flac file')\nparser.add_argument('--mic', default=None, type=str, help='device name or number of input microphone')\nparser.add_argument('--list-devices', action='store_true', help='list audio input devices')\n\nargs = parser.parse_args()\nprint(args)\n    \n# list audio devices\nif args.list_devices:\n    list_audio_devices()\n    sys.exit()\n    \n# load the model\nasr = ASR(args.model)\n\n# create the audio input stream\nstream = AudioInput(wav=args.wav, mic=args.mic, \n                     sample_rate=asr.sample_rate, \n                     chunk_size=asr.chunk_size)\n\n# run transcription\nfor samples in stream:\n    results = asr(samples)\n    \n    if asr.classification:\n        print(f\"class '{results[0]}' ({results[1]:.3f})\")\n    else:\n        for transcript in results:\n            print(transcript['text'])\n            \n            if transcript['end']:\n                print('')\n                \nprint('\\naudio stream closed.')\n    "
  },
  {
    "path": "examples/assistant.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport sys\nimport pprint\n\nfrom jetson_voice import (\n    ASR, NLP, TTS, \n    AudioInput, AudioOutput, list_audio_devices,\n    ConfigArgParser\n)\n       \nparser = ConfigArgParser()\n\nparser.add_argument('--asr-model', default='quartznet', type=str, help='ASR model')\nparser.add_argument('--nlp-model', default='distilbert_intent', type=str, help='NLP model')\nparser.add_argument('--tts-model', default='fastpitch_hifigan', type=str, help='TTS model')\nparser.add_argument('--wav', default=None, type=str, help='path to input wav/ogg/flac file')\nparser.add_argument('--mic', default=None, type=str, help='device name or number of input microphone')\nparser.add_argument('--output-device', default=None, type=str, help='device name or number of audio output')\nparser.add_argument('--list-devices', action='store_true', help='list audio input devices')\n\nargs = parser.parse_args()\nprint(args)\n    \n# list audio devices\nif args.list_devices:\n    list_audio_devices()\n    sys.exit()\n    \n# load the models\ntts = TTS(args.tts_model)\nasr = ASR(args.asr_model, add_punctuation=False)\nnlp = NLP(args.nlp_model)\n\nif asr.classification:\n    raise ValueError(f\"'{args.asr_model}' is a classification model - must use a transcription model for agent\")\n\nif nlp.config.type != 'intent_slot':\n    raise ValueError(f\"'{args.nlp_model}' has type '{nlp.config.type}' - the agent requires an intent_slot model\")\n    \n# create the audio streams\naudio_input = AudioInput(wav=args.wav, mic=args.mic, \n                         sample_rate=asr.sample_rate, \n                         chunk_size=asr.chunk_size)\n\naudio_output = AudioOutput(device=args.output_device,\n                           sample_rate=tts.sample_rate)\n\n\ndef get_slot(results, name, default='', threshold=0, merge=True):\n    \"\"\"\n    Retrieve a slot by name from the intent/slot results.\n    The name can be a list of names, and any of them will 
be matched.\n    Only slots with a score above the threshold will be returned.\n    If merge is true, all slots by that name will be combined.\n    If merge is false, the first matching slot will be returned.\n    \"\"\"\n    if isinstance(name, str):\n        name = [name]\n        \n    slots = []\n\n    for slot in results['slots']:\n        if any(slot['slot'] == n for n in name) and slot['score'] >= threshold:\n            slots.append(slot['text'])\n            \n    if len(slots) == 0:\n        return default\n        \n    if len(slots) > 1 and merge:\n        return ' '.join(slots)\n        \n    return slots[0]\n      \n      \ndef generate_response(query):\n    results = nlp(query)\n    pprint.pprint(results)\n    \n    intent = results['intent']\n    \n    if intent == 'general_praise':\n        return \"Why thank you very much!\"\n        \n    elif intent == 'weather_query':\n        place = get_slot(results, 'place_name')\n        date = get_slot(results, 'date')\n        \n        response = \"The weather \"\n        \n        if place: response += 'in ' + place + ' '\n        if date:  response += date + ' '\n        \n        return response + \"is forecast to be sunny with a high of 78 degrees.\"\n        \n    elif intent == 'recommendation_locations':\n        place = get_slot(results, ['place_name', 'business_name'])\n        \n        if not place:\n            return \"Please ask again with the name of a store or restaurant.\"\n          \n        return f\"{place} is located 1 mile away at 1 2 3 Main Street.\"\n        \n    return \"I'm sorry, I don't understand.\"\n    \n# run agent\nfor input_samples in audio_input:\n    transcripts = asr(input_samples)\n\n    for transcript in transcripts:\n        print(transcript['text'])\n        \n        if not transcript['end']:\n            continue\n            \n        print('')\n        \n        response = generate_response(transcript['text'])\n        print(response)\n        \n        
audio_output.write(tts(response))\n\n    \"\"\"\n    if transcripts[0] != 'unknown' and transcripts[1] != 'silence':\n        response = generate_response(transcripts[0])\n        print(response)\n        \n        audio_output.write(tts(response))\n    \"\"\""
  },
  {
    "path": "examples/nlp.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport sys\nimport pprint\nimport readline\n\nfrom jetson_voice import NLP, ConfigArgParser\n\n\nparser = ConfigArgParser()\nparser.add_argument('--model', default='distilbert_sentiment', type=str)\nargs = parser.parse_args()\nprint(args)\n\n# load the model\nmodel = NLP(args.model)\n\n# QA models should run the nlp_qa.py example\ntype = model.config.type\n\nif type == 'qa':\n    raise ValueError(\"please run Question/Answer models with the nlp_qa.py sample\")\n\n\nwhile True:\n    print(f'\\nEnter {type} query, or Q to quit:')\n    query = input('> ')\n    \n    if query.upper() == 'Q':\n        sys.exit()\n    \n    print('')\n    \n    results = model(query)\n        \n    if type == 'intent_slot' or type == 'text_classification':\n        pprint.pprint(results)\n    \n    elif type == 'token_classification':\n        print(f'{model.tag_string(query, results, scores=True)}')\n        "
  },
  {
    "path": "examples/nlp_qa.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport sys\nimport readline\n\nfrom jetson_voice import QuestionAnswer, ConfigArgParser\n\nparser = ConfigArgParser()\nparser.add_argument('--model', default='distilbert_qa_384', type=str)\nparser.add_argument('--top_k', default=1, type=int, help='show the top N answers (default 1)')\nargs = parser.parse_args()\nprint(args)\n\nmodel = QuestionAnswer(args.model)  # load the QA model\n\nbuiltin_context = {\n    \"Amazon\" : \"The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America. \"\n               \"This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres \"\n               \"(2,100,000 sq mi) are covered by the rainforest. The majority of the forest is contained within Brazil, \"\n               \"with 60% of the rainforest, followed by Peru with 13%, and Colombia with 10%.\",\n    \n    \"Geology\" : \"There are three major types of rock: igneous, sedimentary, and metamorphic. Igneous rocks are formed from \"\n                \"melted rock deep inside the Earth. Sedimentary rocks are compressed layers of sand, silt, dead plants, and \"\n                \"animal skeletons. Metamorphic rocks are other rocks that are changed by heat and pressure underground.\",\n    \n    \"Moon Landing\" : \"The first manned Moon landing was Apollo 11 on July, 20 1969. The first human to step on the Moon was \"\n                     \"astronaut Neil Armstrong followed second by Buzz Aldrin. They landed in the Sea of Tranquility with their \"\n                     \"lunar module the Eagle. They were on the lunar surface for 2.25 hours and collected 50 pounds of moon rocks.\",\n           \n    \"Pi\" : \"Some people have said that Pi is tasty but there should be a value for Pi, and the value for Pi is around 3.14. \"\n           \"Pi is the ratio of a circle's circumference to it's diameter. 
The constant Pi was first calculated by Archimedes \"\n           \"in ancient Greece around the year 250 BC.\",\n           \n    \"Super Bowl 55\" : \"Super Bowl 55 took place on February 7, 2021 in Tampa, Florida between the Kansas City Chiefs and \"\n                      \"the Tampa Bay Buccaneers.  The Tampa Bay Buccaneers won by a score of 31 to 9. In his first season \"\n                      \"with Tampa Bay, it was quarterback Tom Brady's seventh Super Bowl win in nine appearances.\",\n}\n\ncontext = builtin_context['Amazon']\n\ndef print_context():\n    print('\\nContext:')\n    print(context)\n    \ndef parse_commands(entry):\n    \"\"\"\n    Parse 'C' command for changing context, 'P' to print context, and 'Q' for quit.\n    Returns true if a command was entered, otherwise false.\n    \"\"\"\n    global context\n\n    if entry == 'C':\n        print('\\nSelect from one of the following topics, or enter your own context paragraph:')\n        for idx, key in enumerate(builtin_context):\n            print(f'   {idx+1}. 
{key}')\n        entry = input('> ')\n        try:  # try parsing as a number\n            num = int(entry)\n            if num > 0 and num <= len(builtin_context):\n                context = builtin_context[list(builtin_context.keys())[num-1]]\n            else:\n                print('Invalid entry')\n        except:  # try looking up topic name, otherwise custom paragraph\n            if entry in builtin_context:\n                context = builtin_context[entry.lower()]\n            else:\n                context = entry\n                \n        print_context()\n        return True\n        \n    elif entry == 'P':\n        print_context()\n        return True\n    elif entry == 'Q':\n        sys.exit()\n        \n    return False\n    \nprint_context()\n\nwhile True:\n    print('\\nEnter a question, C to change context, P to print context, or Q to quit:')\n    entry = input('> ')\n    \n    if parse_commands(entry.upper()):\n        continue\n    \n    query = {\n        'context' : context,\n        'question' : entry\n    }\n    \n    results = model(query, top_k=args.top_k)\n    \n    if args.top_k == 1:\n        results = [results]\n        \n    for result in results:\n        print('\\nAnswer:', result['answer'])\n        print('Score: ', result['score'])\n        "
  },
  {
    "path": "examples/tts.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport sys\nimport time\nimport readline\n\nfrom jetson_voice import TTS, ConfigArgParser, AudioOutput, list_audio_devices\nfrom soundfile import SoundFile\n\n\nparser = ConfigArgParser()\n\nparser.add_argument('--model', default='fastpitch_hifigan', type=str)\nparser.add_argument('--warmup', default=5, type=int, help='the number of warmup runs')\nparser.add_argument(\"--output-device\", default=None, type=str, help='output audio device to use')\nparser.add_argument(\"--output-wav\", default=None, type=str, help='output directory or wav file to write to')\nparser.add_argument('--list-devices', action='store_true', help='list audio input devices')\n\nargs = parser.parse_args()\nprint(args)\n\n# list audio devices\nif args.list_devices:\n    list_audio_devices()\n    sys.exit()\n    \n# load the model\ntts = TTS(args.model)\n\n# open output audio device\nif args.output_device:\n    audio_device = AudioOutput(args.output_device, tts.sample_rate)\n\n# create output wav directory\nif args.output_wav:\n    wav_is_dir = len(os.path.splitext(args.output_wav)[1]) == 0\n    wav_count = 0\n    if wav_is_dir and not os.path.exists(args.output_wav):\n        os.makedirs(args.output_wav)\n\n\nwhile True:\n    print(f'\\nEnter text, or Q to quit:')\n    text = input('> ')\n    \n    if text.upper() == 'Q':\n        sys.exit()\n    \n    print('')\n    \n    # run the TTS\n    for run in range(args.warmup+1):\n        start = time.perf_counter()\n        audio = tts(text)\n        stop = time.perf_counter()\n        latency = stop-start\n        duration = audio.shape[0]/tts.sample_rate\n        print(f\"Run {run} -- Time to first audio: {latency:.3f}s. Generated {duration:.2f}s of audio. 
RTFx={duration/latency:.2f}.\")\n        \n    # output the audio\n    if args.output_device:\n        audio_device.write(audio)\n    \n    if args.output_wav:\n        wav_path = os.path.join(args.output_wav, f'{wav_count}.wav') if wav_is_dir else args.output_wav\n        wav = SoundFile(wav_path, mode='w', samplerate=tts.sample_rate, channels=1)\n        wav.write(audio)\n        wav.close()\n        wav_count += 1\n        print(f\"\\nWrote audio to {wav_path}\")\n\n    "
  },
  {
    "path": "jetson_voice/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .utils import (\r\n    find_resource, list_models, global_config, ConfigDict, ConfigArgParser,\r\n    list_audio_devices, list_audio_inputs, list_audio_outputs, AudioInput, AudioOutput \r\n)\r\n\r\nfrom .asr import ASR, ASRService\r\nfrom .tts import TTS, TTSService\r\n\r\nfrom .nlp import (NLP,\r\n    IntentSlot, IntentSlotService, \r\n    QuestionAnswer, QuestionAnswerService,\r\n    TextClassification, TextClassificationService,\r\n    TokenClassification, TokenClassificationService,\r\n)\r\n\r\nfrom .auto import AutoModel\r\n\r\n__version__ = global_config.version\r\n"
  },
  {
    "path": "jetson_voice/asr.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice.utils import load_resource\n\n\ndef ASR(resource, *args, **kwargs):\n    \"\"\"\n    Loads a streaming ASR service or model.\n    See the ASRService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'riva' : 'jetson_voice.backends.riva.RivaASRService',\n        'tensorrt' : 'jetson_voice.models.asr.ASREngine',\n        'onnxruntime' : 'jetson_voice.models.asr.ASREngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs)\n\n    \nclass ASRService():\n    \"\"\"\n    Streaming ASR service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        self.config = config\n        \n    def __call__(self, samples):\n        \"\"\"\n        Transcribe streaming audio samples to text, returning the running phrase.\n        Phrases are broken up when a break in the audio is detected (i.e. end of sentence)\n        \n        Parameters:\n          samples (array) -- Numpy array of audio samples.\n\n        Returns a list[dict] of the running transcripts with the following keys:\n        \n          text (string) -- the transcript of the current sentence\n          words (list[dict]) -- a list of word dicts that make up the sentence\n          end (bool) -- if true, end-of-sentence due to silence\n          \n        Each transcript represents one phrase/sentence.  When a sentence has been determined\n        to be ended, it will be marked with end=True.  Multiple sentence transcripts can be \n        returned if one just ended and another is beginning. \n        \"\"\"\n        pass\n    \n    @property\n    def classification(self):\n        \"\"\"\n        Returns true if this is an ASR classification model (e.g. 
for VAD or keyword spotting)\n        Otherwise, this is an ASR transcription model that converts audio to text.\n        \"\"\"\n        return False\n        \n    @property\n    def sample_rate(self):\n        \"\"\"\n        The sample rate that the model runs at (in Hz)\n        Input audio should be resampled to this rate.\n        \"\"\"\n        pass\n    \n    @property\n    def frame_length(self):\n        \"\"\"\n        Duration in seconds per frame / chunk.\n        \"\"\"\n        pass\n        \n    @property\n    def chunk_size(self):\n        \"\"\"\n        Number of samples per frame/chunk (equal to frame_length * sample_rate)\n        \"\"\"\n        pass\n        \n        \nif __name__ == \"__main__\":\n\n    from jetson_voice import list_audio_devices, AudioInput, ConfigArgParser\n    import sys\n    \n    parser = ConfigArgParser()\n    \n    parser.add_argument('--model', default='quartznet', type=str, help='path to model, service name, or json config file')\n    parser.add_argument('--wav', default=None, type=str, help='path to input wav file')\n    parser.add_argument('--mic', default=None, type=str, help='device name or number of input microphone')\n    parser.add_argument('--list-devices', action='store_true', help='list audio input devices')\n    \n    args = parser.parse_args()\n    print(args)\n    \n    # list audio devices\n    if args.list_devices:\n        list_audio_devices()\n        sys.exit()\n        \n    # load the model\n    asr = ASR(args.model)\n    \n    # create the audio input stream\n    stream = AudioInput(wav=args.wav, mic=args.mic, \n                         sample_rate=asr.sample_rate, \n                         chunk_size=asr.chunk_size)\n    \n    # run transcription\n    for samples in stream:\n        #samples = audio_to_float(samples)\n        #print(f'samples {samples.shape} ({audio_db(samples):.1f} dB)')\n        results = asr(samples)\n        \n        if asr.classification:\n            print(f\"class 
'{results[0]}' ({results[1]:.3f})\")\n        else:\n            for transcript in results:\n                print(transcript['text'])\n                \n                if transcript['end']:\n                    print('')\n                    \n    print('\\naudio stream closed.')\n    "
  },
  {
    "path": "jetson_voice/auto.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice.asr import ASR\nfrom jetson_voice.nlp import IntentSlot, QuestionAnswer, TextClassification, TokenClassification\nfrom jetson_voice.tts import TTS\n\nfrom jetson_voice.utils import load_resource\n\n\ndef AutoModel(resource, domain=None, *args, **kwargs):\n    \"\"\"\n    Factory for automatically loading models and services.\n    First the config is loaded and the type is checked.\n    Then the correct instance for the resource is created.\n    \n    If a domain string is supplied (e.g. 'asr', 'nlp', 'tts'),\n    then only resources from that domain will be created.\n    \"\"\"\n    type_map = {\n        # models\n        'asr' : (ASR, 'asr'),\n        'asr_classification' : (ASR, 'asr'),\n        'intent_slot' : (IntentSlot, 'nlp'),\n        'qa' : (QuestionAnswer, 'nlp'),\n        'text_classification' : (TextClassification, 'nlp'),\n        'token_classification' : (TokenClassification, 'nlp'),\n        'tts': (TTS, 'tts'),\n        \n        # services\n        'jarvis_asr' : (ASR, 'asr')\n    }\n\n    config = load_resource(resource, None, *args, **kwargs)\n    \n    if 'type' not in config:\n        raise ValueError(f\"'type' setting missing from config '{config.path}'\")\n        \n    if config.type not in type_map:\n        raise ValueError(f\"'{config.path}' has invalid 'type' ({config.type})\")\n    \n    if domain:\n        if type_map[config.type][1] != domain.lower():\n            raise ValueError(f\"invalid model selected - '{config.path}' has domain '{type_map[config.type][1]}', but AutoModel() was called with domain={domain}\")\n            \n    return type_map[config.type][0](config, *args, **kwargs)\n"
  },
  {
    "path": "jetson_voice/backends/onnxruntime/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .ort_model import OnnxRuntimeModel\r\n\r\n"
  },
  {
    "path": "jetson_voice/backends/onnxruntime/ort_model.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\n\n# for some reason if PyCUDA isn't initialized before OnnxRuntime\n# and TensorRT is also used, it makes TensorRT error\nimport pycuda.driver as cuda\nimport pycuda.autoinit\n\nimport numpy as np\nimport onnxruntime as ort\n\n\nclass OnnxRuntimeModel:\n    \"\"\"\n    Base class for OnnxRuntime models.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Load an ONNX Runtime model.\n        \"\"\"\n        self.config = config\n        \n        logging.info(f\"loading ONNX model '{self.config.model_path}' with onnxruntime\")\n        self.model = ort.InferenceSession(config.model_path, providers=['CUDAExecutionProvider'])\n        logging.info(f\"loaded ONNX model '{self.config.model_path}' with onnxruntime\")\n        \n        self.inputs = self.model.get_inputs()\n        self.outputs = self.model.get_outputs()\n        \n        for idx, binding in enumerate(self.inputs):\n            print('')\n            print(f\"input {idx} - {binding.name}\")\n            print(f\"   shape: {binding.shape}\")\n            print(f\"   type:  {binding.type}\")\n            print('')\n \n    def execute(self, inputs, return_dict=False, **kwargs):\n        \"\"\"\n        Run the DNN model in TensorRT.  The inputs are provided as numpy arrays in a list/tuple/dict.\n        Note that run() doesn't perform any pre/post-processing - this is typically done in subclasses.\n        \n        Parameters:\n          inputs (array, list[array], dict[array]) -- the network inputs as numpy array(s).\n                         If there is only one input, it can be provided as a single numpy array.\n                         If there are multiple inputs, they can be provided as numpy arrays in a\n                         list, tuple, or dict.  Inputs in lists and tuples are assumed to be in the\n                         same order as the input bindings.  
Inputs in dicts should have keys with the\n                         same names as the input bindings.\n          return_dict (bool) -- If True, the results will be returned in a dict of numpy arrays, where the\n                                keys are the names of the output binding names. By default, the results will \n                                be returned in a list of numpy arrays, in the same order as the output bindings.\n          \n        Returns the model output as a numpy array (if only one output), list[ndarray], or dict[ndarray].\n        \"\"\"\n        if isinstance(inputs, np.ndarray):\n            inputs = [inputs]\n        \n        assert len(inputs) == len(self.inputs)\n        \n        if isinstance(inputs, (list,tuple)):\n            inputs = {self.inputs[i].name : input for i, input in enumerate(inputs)}\n        elif not isinstance(inputs, dict):        \n            raise ValueError(f\"inputs must be a list, tuple, or dict (instead got type '{type(inputs).__name__}')\")\n            \n        outputs = self.model.run(None, inputs)\n        \n        if return_dict:\n            return {self.outputs[i].name : output for i, output in enumerate(outputs)}\n            \n        if len(outputs) == 1:\n            return outputs[0]\n        \n        return outputs"
  },
  {
    "path": "jetson_voice/backends/riva/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .riva_asr import RivaASRService\r\nfrom .riva_tts import RivaTTSService\r\n"
  },
  {
    "path": "jetson_voice/backends/riva/riva_asr.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport grpc\r\nimport queue\r\nimport threading\r\nimport logging\r\n\r\nimport riva_api.audio_pb2 as ra\r\nimport riva_api.riva_asr_pb2 as rasr\r\nimport riva_api.riva_asr_pb2_grpc as rasr_srv\r\n\r\nfrom jetson_voice import ASRService\r\nfrom jetson_voice.utils import audio_to_int16\r\n\r\n    \r\nclass RivaASRService(ASRService):\r\n    \"\"\"\r\n    Riva streaming ASR service.  \r\n    \"\"\"\r\n    def __init__(self, config, *args, **kwargs):\r\n        \"\"\"\r\n        Open a streaming channel to the Riva server for ASR.  This establishes a connection over GRPC \r\n        and sends/recieves the requests and responses asynchronously.  Incoming audio samples get put\r\n        into a request queue that GRPC picks up, and a thread waits on responses to come in.\r\n        \"\"\"\r\n        super(RivaASRService, self).__init__(config, *args, **kwargs)\r\n        \r\n        self.config.setdefault('server', 'localhost:50051')\r\n        self.config.setdefault('sample_rate', 16000)\r\n        self.config.setdefault('frame_length', 1.0)\r\n        self.config.setdefault('request_timeout', 2.0)      # how long to wait for new audio to come in\r\n        self.config.setdefault('response_timeout', 0.05)    # how long to wait for results from riva\r\n        self.config.setdefault('language_code', 'en-US')\r\n        self.config.setdefault('enable_automatic_punctuation', True)\r\n        self.config.setdefault('top_k', 1)\r\n\r\n        logging.info(f'Riva ASR service config:\\n{self.config}')\r\n        \r\n        self.channel = grpc.insecure_channel(self.config.server)\r\n        self.client = rasr_srv.RivaSpeechRecognitionStub(self.channel)\r\n        \r\n        self.recognition_config = rasr.RecognitionConfig(\r\n            encoding = ra.AudioEncoding.LINEAR_PCM,\r\n            sample_rate_hertz = self.config.sample_rate,\r\n            language_code = self.config.language_code,\r\n   
         max_alternatives = self.config.top_k,\r\n            enable_word_time_offsets = True,\r\n            enable_automatic_punctuation = self.config.enable_automatic_punctuation\r\n        )\r\n\r\n        self.streaming_config = rasr.StreamingRecognitionConfig(\r\n            config = self.recognition_config,\r\n            interim_results = True\r\n        )\r\n        \r\n        self.request_queue = queue.Queue()\r\n        self.request_queue.put(rasr.StreamingRecognizeRequest(streaming_config=self.streaming_config))\r\n         \r\n        self.responses = self.client.StreamingRecognize(self)\r\n        self.responses_queue = queue.Queue()\r\n        \r\n        self.response_thread = threading.Thread(target=self.recieve_responses)\r\n        self.response_thread.start()\r\n\r\n    def __call__(self, samples):\r\n        \"\"\"\r\n        Transcribe streaming audio samples to text, returning the running phrase.\r\n        Phrases are broken up when a break in the audio is detected (i.e. end of sentence)\r\n        \r\n        Parameters:\r\n          samples (array) -- Numpy array of audio samples.\r\n\r\n        Returns a list[dict] of the running transcripts with the following keys:\r\n        \r\n          text (string) -- the transcript of the current sentence\r\n          words (list[dict]) -- a list of word dicts that make up the sentence\r\n          end (bool) -- if true, end-of-sentence due to silence\r\n          \r\n        Each transcript represents one phrase/sentence.  When a sentence has been determined\r\n        to be ended, it will be marked with end=True.  Multiple sentence transcripts can be \r\n        returned if one just ended and another is beginning. 
\r\n        \"\"\"\r\n        samples = audio_to_int16(samples)\r\n\r\n        self.request_queue.put(rasr.StreamingRecognizeRequest(audio_content=samples.tobytes()))\r\n        \r\n        transcripts = []\r\n        \r\n        while True:\r\n            try:\r\n                transcripts.append(self.responses_queue.get(block=True, timeout=self.config.response_timeout))\r\n            except queue.Empty:\r\n                break\r\n\r\n        return transcripts\r\n \r\n    def __next__(self):\r\n        \"\"\"\r\n        Retrieve the next request containing audio samples to send to the Riva server.\r\n        This is implemented using an iterator interface as that is what GRPC expects.\r\n        \"\"\"\r\n        try:\r\n            request = self.request_queue.get(block=True, timeout=self.config.request_timeout)\r\n            return request\r\n        except queue.Empty:\r\n            logging.debug(f'{self.config.request_timeout} second timeout occurred waiting for audio samples, stopping Riva ASR service')\r\n            raise StopIteration\r\n        \r\n    def recieve_responses(self):\r\n        \"\"\"\r\n        Wait to recieve responses from the Riva server and parse them.\r\n        \"\"\"\r\n        logging.debug('starting Riva ASR service response reciever thread')\r\n        \r\n        for response in self.responses:  # this is blocking\r\n            if not response.results:\r\n                continue\r\n\r\n            result = response.results[0]\r\n\r\n            if not result.alternatives:\r\n                continue\r\n\r\n            text = result.alternatives[0].transcript\r\n            text = text.strip()\r\n            \r\n            if len(text) == 0:\r\n                continue\r\n                \r\n            self.responses_queue.put({\r\n                'text' : text,\r\n                'end' : result.is_final\r\n            })\r\n\r\n        logging.debug('exiting Riva ASR service response reciever thread')\r\n        \r\n    
@property\r\n    def sample_rate(self):\r\n        \"\"\"\r\n        The sample rate that the model runs at (in Hz)\r\n        Input audio should be resampled to this rate.\r\n        \"\"\"\r\n        return self.config.sample_rate\r\n    \r\n    @property\r\n    def frame_length(self):\r\n        \"\"\"\r\n        Duration in seconds per frame / chunk.\r\n        \"\"\"\r\n        return self.config.frame_length\r\n        \r\n    @property\r\n    def chunk_size(self):\r\n        \"\"\"\r\n        Number of samples per frame/chunk (equal to frame_length * sample_rate)\r\n        \"\"\"\r\n        return int(self.frame_length * self.sample_rate)\r\n\r\n"
  },
  {
    "path": "jetson_voice/backends/riva/riva_tts.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport grpc\r\nimport logging\r\nimport numpy as np\r\n\r\nimport riva_api.audio_pb2 as ra\r\nimport riva_api.riva_tts_pb2 as rtts\r\nimport riva_api.riva_tts_pb2_grpc as rtts_srv\r\n\r\nfrom jetson_voice import TTSService\r\n\r\n    \r\nclass RivaTTSService(TTSService):\r\n    \"\"\"\r\n    Riva streaming TTS service.  \r\n    \"\"\"\r\n    def __init__(self, config, *args, **kwargs):\r\n        \"\"\"\r\n        Open a streaming channel to the Riva server for TTS.  This establishes a connection over GRPC \r\n        and sends/recieves the requests and responses.\r\n        \"\"\"\r\n        super(RivaTTSService, self).__init__(config, *args, **kwargs)\r\n        \r\n        self.config.setdefault('server', 'localhost:50051')\r\n        self.config.setdefault('sample_rate', 22050)        # ignored (will always be 22.05KHz)\r\n        self.config.setdefault('voice_name', 'ljspeech')    # ignored\r\n        self.config.setdefault('language_code', 'en-US')\r\n\r\n        logging.info(f'Riva TTS service config:\\n{self.config}')\r\n        \r\n        self.channel = grpc.insecure_channel(self.config.server)\r\n        self.client = rtts_srv.RivaSpeechSynthesisStub(self.channel)\r\n\r\n    def __call__(self, text):\r\n        \"\"\"\r\n        Generate audio from text.\r\n        \r\n        Parameters:\r\n          text (string) -- The phrase to convert to audio.\r\n\r\n        Returns audio samples in a numpy array.\r\n        \"\"\"\r\n        req = rtts.SynthesizeSpeechRequest()\r\n        \r\n        req.text = text\r\n        req.language_code = self.config.language_code\r\n        req.sample_rate_hz = self.config.sample_rate\r\n        req.voice_name = self.config.voice_name\r\n        req.encoding = ra.AudioEncoding.LINEAR_PCM\r\n\r\n        resp = self.client.Synthesize(req)\r\n        \r\n        samples = np.frombuffer(resp.audio, dtype=np.float32)\r\n        return samples\r\n    
\r\n    @property\r\n    def sample_rate(self):\r\n        \"\"\"\r\n        Get the output sample rate (in Hz)\r\n        \"\"\"\r\n        return self.config.sample_rate"
  },
  {
    "path": "jetson_voice/backends/tensorrt/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .trt_model import TRTModel\r\n\r\n"
  },
  {
    "path": "jetson_voice/backends/tensorrt/trt_binding.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport logging\nimport tensorrt as trt\n\nimport pycuda.driver as cuda\nimport pycuda.autoinit\n\n\nclass Binding:\n    \"\"\"\n    Represents an input/output tensor to the model.\n    \"\"\"\n    def __init__(self, model, index):\n        \"\"\"\n        Parameters:\n          model (TRTModel) -- parent model instance\n          index (int) -- index of the binding in the model\n        \"\"\"\n        self.model = model\n        self.index = index\n\n        self.name  = model.trt_engine.get_binding_name(index)\n        self.shape = tuple(model.trt_engine.get_binding_shape(index))\n        self.dtype = model.trt_engine.get_binding_dtype(index)\n        self.input = model.trt_engine.binding_is_input(index)\n        self.size  = max(trt.volume(self.shape) * self.dtype.itemsize, 0)\n        \n        self.dynamic = (self.size <= 0)   \n        self.profiles = []\n            \n        if self.input:\n            for i in range(model.trt_engine.num_optimization_profiles):\n                profile = model.trt_engine.get_profile_shape(i, index)\n                self.profiles.append(dict(\n                    min = profile[0],\n                    opt = profile[1],\n                    max = profile[2]))\n        \n        self.alloc()\n          \n    def alloc(self, shape=None):\n        \"\"\"\n        Allocate memory for the binding. alloc() is called automatically when needed.\n        If new shape is provided, it will update the internal state. 
\n        \"\"\"\n        if shape is not None:\n            self.shape = shape\n            \n        self.size = trt.volume(self.shape) * self.dtype.itemsize\n        \n        if self.size <= 0:  # dynamic with shape not yet set\n            self.host = None\n            self.device = None\n            return\n            \n        self.host = None if self.input else cuda.pagelocked_empty(self.shape, dtype=trt.nptype(self.dtype))\n        self.device = cuda.mem_alloc(self.size)\n        \n    def set_shape(self, shape):\n        \"\"\"\n        Set the shape of a dynamic input binding.\n        \"\"\"\n        if not self.dynamic:\n            raise ValueError(f\"binding '{self.name}' is not dynamic\")\n            \n        if not self.input:\n            raise ValueError(f\"binding '{self.name}' is not an input\")\n            \n        # check to see if the shape already matches\n        if self.shape == shape:\n            logging.debug(f\"binding '{self.name}' already has shape {shape}\")\n            return\n            \n        logging.debug(f\"binding '{self.name}' has new shape {shape}\")\n        \n        # set the new shape\n        if not self.model.trt_context.set_binding_shape(self.index, shape):\n            raise ValueError(f\"failed to set binding '{self.name}' with shape {shape}\")\n           \n        # re-allocate tensor memory\n        self.alloc(shape)\n    \n    def query_shape(self):\n        \"\"\"\n        Updates the shape of a dynamic output binding.\n        \"\"\"\n        if not self.dynamic:\n            return\n            \n        if self.input:\n            raise ValueError(f\"binding '{self.name}' is not an output\")\n        \n        # get the new shape\n        shape = tuple(self.model.trt_context.get_binding_shape(self.index))\n        \n        # check to see if the shape already matches\n        if self.shape == shape:\n            logging.debug(f\"binding '{self.name}' already has shape {shape}\")\n            
return\n        \n        logging.debug(f\"binding '{self.name}' has new output shape {shape}\")\n        \n        # re-allocate tensor memory\n        self.alloc(shape)\n        return shape\n        \n    def __str__(self):\n        return (\n            f\"binding {self.index} - '{self.name}'\\n\"\n            f\"   input:    {self.input}\\n\"\n            f\"   shape:    {self.shape}\\n\"\n            f\"   dtype:    {self.dtype}\\n\"\n            f\"   size:     {self.size}\\n\"\n            f\"   dynamic:  {self.dynamic}\\n\"\n            f\"   profiles: {self.profiles}\\n\"\n        )"
  },
  {
    "path": "jetson_voice/backends/tensorrt/trt_builder.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport time\nimport json\nimport logging\nimport tensorrt as trt\n\nimport pycuda.driver as cuda\nimport pycuda.autoinit\n\nTRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)\n\ndef build_engine(config, \n                 output=None, \n                 precision='fp16',\n                 batch_size=1,\n                 dynamic_shapes=None,\n                 workspace=128, \n                 parse_only=False):\n    \"\"\"\n    Build TensorRT engine from ONNX model.\n    \n    Parameters:\n      model (string) -- path to ONNX model\n      config (string) -- path to model configuration json (will be inferred from model path if empty)\n      output (string) -- path to output serialized TensorRT engine (will be inferred from model path if empty)\n      precision (string) -- fp32 or fp16 (int8 not currently supported)\n      batch_size (int) -- the maximum batch size (default 1)\n      dynamic_shape (dict) -- dynamic shape profiles for min/max/opt\n      workspace (int) -- builder workspace memory size (in MB)\n      parse_only (bool) -- if true, test parsing the model before exiting without building the TensorRT engine\n      \n    Returns the built TensorRT engine (ICudaEngine)\n    \"\"\"\n    # set default output path\n    if output is None or output == '':\n        output = f'{os.path.splitext(config.model_path)[0]}.engine'\n\n    # create TensorRT resources\n    builder = trt.Builder(TRT_LOGGER)\n    builder_config = builder.create_builder_config()\n    network = builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n    parser = trt.OnnxParser(network, TRT_LOGGER)\n    \n    builder_config.max_workspace_size = workspace * 1 << 20\n    \n    # set precision\n    precision = precision.lower()\n    \n    if precision == 'fp16':\n        builder_config.set_flag(trt.BuilderFlag.FP16)\n        logging.info(f'enabled FP16 precision')\n    elif precision == 'int8':\n        # 
https://github.com/NVIDIA/TensorRT/blob/d7baf010e4396c87d58e4d8a33052c01c2d89325/demo/BERT/builder.py#L592\n        raise NotImplementedError('INT8 support not yet implemented')\n        \n    # load the model (from ONNX)\n    logging.info(f'loading {config.model_path}')\n    \n    with open(config.model_path, 'rb') as model_file:\n        if not parser.parse(model_file.read()):\n            logging.error(f'failed to parse ONNX model {config.model_path}')\n            for error in range(parser.num_errors): \n                print (parser.get_error(error))\n            return None \n\n    # create dynamic shape profile\n    # TODO refactor this to an abstract .get_dynamic_shapes() implementation in each subclass\n    # TODO this currently uses same shape for all inputs - allow for different shape profiles\n    profile = builder.create_optimization_profile()\n    opt_shape = None\n    \n    \"\"\"\n    if model_type == 'qa' or model_type == 'text_classification' or model_type == 'token_classification':\n        min_shape = (1, 1)  # (batch_size, sequence_length)\n        max_shape = (batch_size, model_config['dataset']['max_seq_length'])\n    elif model_type == 'intent_slot':\n        min_shape = (1, 1)  # (batch_size, sequence_length)\n        max_shape = (batch_size, model_config['language_model']['max_seq_length'])\n    elif model_type == 'asr':\n        features = model_config['preprocessor']['features']\n        sample_rate = model_config['preprocessor']['sample_rate']\n        sample_to_fft = 1.0 / 160.0  # rough conversion from samples to MEL spectrogram dims\n        sample_multiplier = sample_rate * sample_to_fft\n        \n        min_shape = (batch_size, features, int(0.5 * sample_multiplier))  # minimum plausible frame length\n        opt_shape = (batch_size, features, int(1.2 * sample_multiplier))  # default of .1s overlap factor (1,64,121)\n        max_shape = (batch_size, features, int(3.0 * sample_multiplier))  # enough for 1s overlap factor\n    elif 
model_type == 'asr_classification':\n        features = model_config['preprocessor']['n_mels']\n        sample_rate = model_config['sample_rate']\n        sample_to_fft = 1.0 / 160.0  # rough conversion from samples to MEL spectrogram dims\n        sample_multiplier = sample_rate * sample_to_fft\n        \n        min_shape = (batch_size, features, int(0.5 * sample_multiplier))  # minimum plausible frame length\n        opt_shape = (batch_size, features, int(1.2 * sample_multiplier))  # default of .1s overlap factor (1,64,121)\n        max_shape = (batch_size, features, int(3.0 * sample_multiplier))  # enough for 1s overlap factor\n    elif model_type == 'tts_vocoder':\n        min_shape = (batch_size, model_config['features'], 1)\n        opt_shape = (batch_size, model_config['features'], 160)  # ~5-6 words\n        max_shape = (batch_size, model_config['features'], 512)  # ~15-20 words?\n    else:\n        raise NotImplementedError(f\"model type '{model_type}' is unrecognized or not supported\")\n    \"\"\"           \n    \n    # TODO support different shape profiles for different input tensors\n    if dynamic_shapes is not None:        \n        if 'min' not in dynamic_shapes:\n            dynamic_shapes['min'] = dynamic_shapes['max']\n            \n        if 'opt' not in dynamic_shapes:\n            dynamic_shapes['opt'] = dynamic_shapes['max']\n            \n        for i in range(network.num_inputs):  # TODO confirm that input is in fact dynamic\n            profile.set_shape(network.get_input(i).name, min=dynamic_shapes['min'], opt=dynamic_shapes['opt'], max=dynamic_shapes['max'])\n\n        builder_config.add_optimization_profile(profile)\n                    \n    def print_summary():\n        print('')\n        print('----------------------------------------------------')\n        print(' BUILDER CONFIGURATION')\n        print('----------------------------------------------------')\n        print(f'  - model     {config.model_path}')\n        print(f'  
- config    {config.path}')\n        print(f'  - output    {output}')\n        print(f'  - type      {config.type}')\n        print(f'  - layers    {network.num_layers}')\n        print(f'  - inputs    {network.num_inputs}')\n        print(f'  - outputs   {network.num_outputs}')\n        print(f'  - precision {precision}')\n        print(f'  - workspace {workspace}')\n        print('')\n        \n        for i in range(network.num_inputs):\n            tensor = network.get_input(i)\n            \n            print(f'  - input {i}:')\n            print(f'      - name     {tensor.name}')\n            print(f'      - shape    {tensor.shape}')\n            print(f'      - dtype    {tensor.dtype}')\n            \n        for i in range(network.num_outputs):\n            tensor = network.get_output(i)\n            \n            print(f'  - output {i}:')\n            print(f'      - name     {tensor.name}')\n            print(f'      - shape    {tensor.shape}')\n            print(f'      - dtype    {tensor.dtype}')\n           \n    print_summary()\n    \n    if parse_only:\n        return None\n    \n    # build the engine\n    build_start_time = time.time()\n    \n    engine = builder.build_engine(network, builder_config)\n    \n    if engine is None:\n        raise ValueError(f\"failed to build TensorRT engine for '{config.model_path}'\")\n        \n    build_time_elapsed = (time.time() - build_start_time)\n    print(f'\\nbuilt engine in {build_time_elapsed} seconds')\n\n    print_summary()\n    \n    # save engine\n    print('\\nserializing engine...')\n    serialized_engine = engine.serialize()\n    with open(output, \"wb\") as engine_file:\n        engine_file.write(serialized_engine)\n    print(f'saved engine to {output}')\n        \n    return engine\n        \n\n'''\nif __name__ == \"__main__\":\n\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    \n    parser.add_argument('--config', default='', type=str)\n    parser.add_argument('--output', 
default='', type=str)\n    parser.add_argument('--precision', default='fp16', choices=['fp32', 'fp16', 'int8'], type=str)\n    parser.add_argument('--batch-size', default=1, type=int) # max batch size\n    parser.add_argument('--workspace', default=utils.DEFAULT_WORKSPACE, type=int)\n    parser.add_argument('--parse-only', action='store_true')\n    \n    args = parser.parse_args()\n    print(args)\n    \n    build_engine(config=args.config,\n                 output=args.output,\n                 precision=args.precision,\n                 batch_size=args.batch_size,\n                 workspace=args.workspace,\n                 parse_only=args.parse_only)\n'''\n\n"
  },
  {
    "path": "jetson_voice/backends/tensorrt/trt_model.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport time\nimport json\nimport logging\nimport pprint\n\nimport numpy as np\nimport tensorrt as trt\n\nimport pycuda.driver as cuda\nimport pycuda.autoinit\n\nfrom .trt_builder import build_engine, TRT_LOGGER\nfrom .trt_binding import Binding\n\n\nclass TRTModel:\n    \"\"\"\n    Base class for TensorRT models.\n    \"\"\"\n    def __init__(self, config, dynamic_shapes=None, *args, **kwargs):\n        \"\"\"\n        Load a TensorRT model from ONNX or serialized TensorRT engine.\n        \n        Parameters:\n          config (ConfigDict) -- configuration dict\n          dynamic_shapes (dict) -- dynamic shape profiles for min/max/opt\n        \"\"\"\n        self.config = config\n            \n        # determine if the TensorRT engine already exists\n        model_root, model_ext = os.path.splitext(self.config.model_path)\n        model_ext = model_ext.lower()\n        \n        if model_ext == '.onnx':\n            engine_path = model_root + '.engine'\n            if os.path.exists(engine_path):\n                logging.info(f'loading cached TensorRT engine from {engine_path}')\n                self.config.model_path = engine_path\n                model_ext = '.engine'\n                \n        # either build or load TensorRT engine\n        if model_ext == '.onnx':\n            self.trt_engine = build_engine(self.config, dynamic_shapes=dynamic_shapes)\n        elif model_ext == '.engine' or model_ext == '.plan':\n            with open(self.config.model_path, 'rb') as f:\n                self.trt_runtime = trt.Runtime(TRT_LOGGER)\n                self.trt_engine  = self.trt_runtime.deserialize_cuda_engine(f.read())\n        else:\n            raise ValueError(f\"invalid model extension '{model_ext}' (should be .onnx, .engine, or .plan)\")\n            \n        if self.trt_engine is None:\n            raise IOError(f'failed to load TensorRT engine from {self.model_path}')\n                
\n        self.trt_context = self.trt_engine.create_execution_context()\n        logging.info(f'loaded TensorRT engine from {self.config.model_path}')\n\n        # create a stream in which to copy inputs/outputs and run inference\n        self.stream = cuda.Stream()\n        \n        # enumerate bindings\n        self.bindings = []\n        self.inputs  = []\n        self.outputs = []\n\n        for i in range(len(self.trt_engine)):\n            binding = Binding(self, i)\n            self.bindings.append(binding)\n            \n            if binding.input:\n                self.inputs.append(binding)\n            else:\n                self.outputs.append(binding)\n        \n        for binding in self.bindings:\n            print(f'\\n{binding}')\n\n    def execute(self, inputs, sync=True, return_dict=False, **kwargs):\n        \"\"\"\n        Run the DNN model in TensorRT.  The inputs are provided as numpy arrays in a list/tuple/dict.\n        Note that run() doesn't perform any pre/post-processing - this is typically done in subclasses.\n        \n        Parameters:\n          inputs (array, list[array], dict[array]) -- the network inputs as numpy array(s).\n                         If there is only one input, it can be provided as a single numpy array.\n                         If there are multiple inputs, they can be provided as numpy arrays in a\n                         list, tuple, or dict.  Inputs in lists and tuples are assumed to be in the\n                         same order as the input bindings.  Inputs in dicts should have keys with the\n                         same names as the input bindings.\n          sync (bool) -- If True (default), will wait for the GPU to be done processing before returning.\n          return_dict (bool) -- If True, the results will be returned in a dict of numpy arrays, where the\n                                keys are the names of the output binding names. 
By default, the results will \n                                be returned in a list of numpy arrays, in the same order as the output bindings.\n          \n        Returns the model output as a numpy array (if only one output), list[ndarray], or dict[ndarray].\n        \"\"\"\n        if isinstance(inputs, np.ndarray):\n            inputs = [inputs]\n        \n        assert len(inputs) == len(self.inputs)\n        \n        # setup inputs + copy to GPU\n        def setup_binding(binding, input):\n            input = input.astype(trt.nptype(binding.dtype), copy=False)\n            if binding.dynamic: \n                binding.set_shape(input.shape)\n            cuda.memcpy_htod_async(binding.device, np.ascontiguousarray(input), self.stream)\n            \n        if isinstance(inputs, (list,tuple)):\n            # positional inputs map onto the input bindings only (self.bindings also contains the outputs)\n            for idx, input in enumerate(inputs):\n                setup_binding(self.inputs[idx], input)\n        elif isinstance(inputs, dict):        \n            for binding_name in inputs:\n                setup_binding(self.find_binding(binding_name), inputs[binding_name])\n        else:\n            raise ValueError(f\"inputs must be a list, tuple, or dict (instead got type '{type(inputs).__name__}')\")\n            \n        assert self.trt_context.all_binding_shapes_specified\n        assert self.trt_context.all_shape_inputs_specified \n        \n        # query new dynamic output shapes\n        for output in self.outputs:\n            output.query_shape()\n\n        # run inference\n        self.trt_context.execute_async_v2(\n            bindings=[int(binding.device) for binding in self.bindings], \n            stream_handle=self.stream.handle\n        )\n          \n        # copy outputs to CPU\n        for output in self.outputs:\n            cuda.memcpy_dtoh_async(output.host, output.device, self.stream)\n          \n        # wait for completion\n        if sync:\n            self.stream.synchronize()\n            \n        # return results\n        
if return_dict:\n            results = {}\n            for output in self.outputs:\n                results[output.name] = output.host\n            return results\n        else:\n            if len(self.outputs) == 1:\n                return self.outputs[0].host\n            else:\n                return tuple([output.host for output in self.outputs])\n\n    def find_binding(self, name):\n        \"\"\"\n        Lookup an input/output binding by name\n        \"\"\"\n        for binding in self.bindings:\n            if binding.name == name: \n                return binding   \n        logging.error(f\"couldn't find binding with name '{name}'\")\n        return None\n        \n    def set_shape(self, binding, shape):\n        \"\"\"\n        Set the shape of a dynamic binding.\n        \n        Parameters:\n          binding (int, str, or Binding) -- binding index, binding name, or the Binding object itself\n          shape -- the new shape (e.g. the .shape of the numpy array to be bound)\n        \"\"\"\n        if isinstance(binding, int):\n            binding = self.bindings[binding]\n        elif isinstance(binding, str):\n            binding = self.find_binding(binding)\n        elif not isinstance(binding, Binding):\n            raise ValueError(f'binding must be specified as int, string, or Binding (got {type(binding).__name__})')\n            \n        binding.set_shape(shape)\n    \n"
  },
  {
    "path": "jetson_voice/models/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .asr import ASREngine\r\nfrom .nlp import IntentSlotEngine, QuestionAnswerEngine, TextClassificationEngine, TokenClassificationEngine\r\nfrom .tts import TTSEngine"
  },
  {
    "path": "jetson_voice/models/asr/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .asr_engine import ASREngine\r\n"
  },
  {
    "path": "jetson_voice/models/asr/asr_engine.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport time\nimport pprint\nimport logging\nimport importlib\n\nimport torch\nimport numpy as np\n\nfrom .ctc_decoder import CTCDecoder\n\nfrom jetson_voice.asr import ASRService\nfrom jetson_voice.utils import audio_to_float, global_config, load_model, softmax\n\n      \nclass ASREngine(ASRService):\n    \"\"\"\n    Streaming ASR (Automatic Speech Recognition) model in TensorRT or onnxruntime.\n    This model is primarily designed to be used on a live audio source like a microphone.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Loads a streaming ASR model from ONNX or serialized TensorRT engine.\n        \n        Parameters:\n          model (string) -- path to ONNX model or serialized TensorRT engine/plan\n          config (string) -- path to model configuration json (will be inferred from model path if empty)\n        \"\"\"\n        super(ASREngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'asr' and self.config.type != 'asr_classification':\n            raise ValueError(f\"{self.config.model_path} isn't an ASR model (type '{self.config.type}'\")\n\n        # set some default config options that are non-standard in nemo\n        if 'streaming' not in self.config:\n            self.config['streaming'] = {}\n        \n        self.config['streaming'].setdefault('frame_length', 1.0)     # duration of signal frame, seconds (TODO shorter defaults for VAD/command classifiers)\n        self.config['streaming'].setdefault('frame_overlap', 0.5)    # duration of overlap before/after current frame, seconds\n        \n        # some config changes for streaming\n        if not self.classification:\n            self.config['preprocessor']['dither'] = 0.0\n            self.config['preprocessor']['pad_to'] = 0\n        \n            if 'ctc_decoder' not in self.config:\n                self.config['ctc_decoder'] = {}\n                 
   \n            self.config['ctc_decoder'].setdefault('type', 'greedy')        # greedy or beamsearch\n            self.config['ctc_decoder'].setdefault('add_punctuation', True) # add period to the end of sentences\n        \n            if 'add_punctuation' in kwargs:\n                self.config['ctc_decoder']['add_punctuation'] = kwargs['add_punctuation']\n                logging.info(f\"add_punctuation = {kwargs['add_punctuation']}\")\n                \n        if not self.classification and self.config['preprocessor']['features'] == 64:   # TODO normalization coefficients for citrinet (N=80)\n            normalization = {}\n\n            normalization['fixed_mean'] = [\n                 -14.95827016, -12.71798736, -11.76067913, -10.83311182,\n                 -10.6746914,  -10.15163465, -10.05378331, -9.53918999,\n                 -9.41858904,  -9.23382904,  -9.46470918,  -9.56037,\n                 -9.57434245,  -9.47498732,  -9.7635205,   -10.08113074,\n                 -10.05454561, -9.81112681,  -9.68673603,  -9.83652977,\n                 -9.90046248,  -9.85404766,  -9.92560366,  -9.95440354,\n                 -10.17162966, -9.90102482,  -9.47471025,  -9.54416855,\n                 -10.07109475, -9.98249912,  -9.74359465,  -9.55632283,\n                 -9.23399915,  -9.36487649,  -9.81791084,  -9.56799225,\n                 -9.70630899,  -9.85148006,  -9.8594418,   -10.01378735,\n                 -9.98505315,  -9.62016094,  -10.342285,   -10.41070709,\n                 -10.10687659, -10.14536695, -10.30828702, -10.23542833,\n                 -10.88546868, -11.31723646, -11.46087382, -11.54877829,\n                 -11.62400934, -11.92190509, -12.14063815, -11.65130117,\n                 -11.58308531, -12.22214663, -12.42927197, -12.58039805,\n                 -13.10098969, -13.14345864, -13.31835645, -14.47345634]\n                 \n            normalization['fixed_std'] = [\n                 3.81402054, 4.12647781, 4.05007065, 3.87790987,\n            
     3.74721178, 3.68377423, 3.69344,    3.54001005,\n                 3.59530412, 3.63752368, 3.62826417, 3.56488469,\n                 3.53740577, 3.68313898, 3.67138151, 3.55707266,\n                 3.54919572, 3.55721289, 3.56723346, 3.46029304,\n                 3.44119672, 3.49030548, 3.39328435, 3.28244406,\n                 3.28001423, 3.26744937, 3.46692348, 3.35378948,\n                 2.96330901, 2.97663111, 3.04575148, 2.89717604,\n                 2.95659301, 2.90181116, 2.7111687,  2.93041291,\n                 2.86647897, 2.73473181, 2.71495654, 2.75543763,\n                 2.79174615, 2.96076456, 2.57376336, 2.68789782,\n                 2.90930817, 2.90412004, 2.76187531, 2.89905006,\n                 2.65896173, 2.81032176, 2.87769857, 2.84665271,\n                 2.80863137, 2.80707634, 2.83752184, 3.01914511,\n                 2.92046439, 2.78461139, 2.90034605, 2.94599508,\n                 2.99099718, 3.0167554,  3.04649716, 2.94116777]\n                 \n            self.config['preprocessor']['normalize'] = normalization\n        \n        # create preprocessor instance\n        preprocessor_name = self.config['preprocessor']['_target_'].rsplit(\".\", 1)\n        preprocessor_class = getattr(importlib.import_module(preprocessor_name[0]), preprocessor_name[1])\n        logging.debug(f'ASR preprocessor - {preprocessor_class}')\n\n        preprocessor_config = self.config['preprocessor'].copy()\n        preprocessor_config.pop('_target_')\n\n        self.preprocessor = preprocessor_class(**preprocessor_config)\n\n        # load the model\n        features = self.config.preprocessor.n_mels if self.classification else self.config.preprocessor.features\n        time_to_fft = self.sample_rate * (1.0 / 160.0)     # rough conversion from samples to MEL spectrogram dims\n        \n        dynamic_shapes = {\n            'min' : (1, features, int(0.1 * time_to_fft)), # minimum plausible frame length\n            'opt' : (1, features, int(1.5 * 
time_to_fft)), # default of .5s overlap factor (1,64,121)\n            'max' : (1, features, int(3.0 * time_to_fft))  # enough for 2s overlap factor\n        }\n        \n        self.model = load_model(self.config, dynamic_shapes)\n        \n        # create CTC decoder\n        if not self.classification:\n            self.ctc_decoder = CTCDecoder.from_config(self.config['ctc_decoder'],\n                                                      self.config['decoder']['vocabulary'],\n                                                      os.path.dirname(self.config.model_path))\n                                                      \n            logging.info(f\"CTC decoder type: '{self.ctc_decoder.type}'\")\n            \n        # create streaming buffer\n        self.n_frame_len = int(self.frame_length * self.sample_rate)\n        self.n_frame_overlap = int(self.frame_overlap * self.sample_rate)\n        \n        self.buffer_length = self.n_frame_len + self.n_frame_overlap\n        self.buffer_duration = self.buffer_length / self.sample_rate\n        \n        self.buffer = np.zeros(shape=self.buffer_length, dtype=np.float32)  # 2*self.n_frame_overlap\n    \n        \n    def __call__(self, samples):\n        \"\"\"\n        Transcribe streaming audio samples to text, returning the running phrase.\n        Phrases are broken up when a break in the audio is detected (i.e. 
end of sentence)\n        \n        Parameters:\n          samples (array) -- Numpy array of audio samples.\n\n        Returns a dict of the running phrase.\n          transcript (string) -- the current transcript\n          latest (string) -- the latest additions to the transcript\n          end (bool) -- if true, end-of-sequence due to silence\n        \"\"\"\n        samples = audio_to_float(samples)\n        \n        if len(samples) < self.n_frame_len:\n            samples = np.pad(samples, [0, self.n_frame_len - len(samples)], 'constant')\n            \n        self.buffer[:self.n_frame_overlap] = self.buffer[-self.n_frame_overlap:]\n        self.buffer[self.n_frame_overlap:] = samples\n        \n        if global_config.profile: preprocess_begin = time.perf_counter()\n        \n        # apply pre-processing\n        preprocessed_signal, _ = self.preprocessor(\n            input_signal=torch.as_tensor(self.buffer, dtype=torch.float32).unsqueeze(dim=0), \n            length=torch.as_tensor(self.buffer.size, dtype=torch.int64).unsqueeze(dim=0)\n        )\n\n        if global_config.profile:\n            logging.info(f'preprocess time: {time.perf_counter() - preprocess_begin}')\n            network_begin = time.perf_counter()\n        \n        # run the asr model\n        logits = self.model.execute(torch_to_numpy(preprocessed_signal))\n        logits = np.squeeze(logits)\n        logits = softmax(logits, axis=-1)\n\n        if global_config.profile: logging.info(f'network time: {time.perf_counter() - network_begin}')\n        \n        self.timestep_duration = self.buffer_duration / logits.shape[0]\n        self.n_timesteps_frame = int(self.frame_length / self.timestep_duration)\n        self.n_timesteps_overlap = int(self.frame_overlap / self.timestep_duration)\n\n        if self.classification:\n            argmax = np.argmax(logits)\n            prob = logits[argmax]\n            return (self.config['labels'][argmax], prob)\n        else:\n            
self.ctc_decoder.set_timestep_duration(self.timestep_duration)\n            self.ctc_decoder.set_timestep_delta(self.n_timesteps_frame)\n\n            if global_config.profile: ctc_decoder_begin = time.perf_counter()\n            transcripts = self.ctc_decoder.decode(logits)\n            if global_config.profile: logging.info(f'ctc_decoder time: {time.perf_counter() - ctc_decoder_begin}')\n            \n            return transcripts\n\n    @property\n    def classification(self):\n        \"\"\"\n        Returns true if this is an ASR classification model.\n        \"\"\"\n        return self.config.type == 'asr_classification'\n        \n    @property\n    def sample_rate(self):\n        \"\"\"\n        The sample rate that the model runs at.\n        Input audio should be resampled to this rate.\n        \"\"\"\n        return self.config['sample_rate'] if self.classification else self.config['preprocessor']['sample_rate']\n        \n    @property\n    def frame_length(self):\n        \"\"\"\n        Duration in seconds per frame / chunk.\n        \"\"\"\n        return self.config['streaming']['frame_length']\n        \n    @property\n    def frame_overlap(self):\n        \"\"\"\n        Duration of overlap in seconds before/after current frame.\n        \"\"\"\n        return self.config['streaming']['frame_overlap']\n    \n    @property\n    def chunk_size(self):\n        \"\"\"\n        Number of samples per frame/chunk (equal to frame_length * sample_rate)\n        \"\"\"\n        return self.n_frame_len\n\n\ndef torch_to_numpy(tensor):\n    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n    \n                    "
  },
  {
    "path": "jetson_voice/models/asr/ctc_beamsearch.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\n\nfrom .ctc_decoder import CTCDecoder\nfrom .ctc_utils import find_silent_intervals, merge_words, rebase_word_times, split_words, transcript_from_words\n\nfrom ctc_decoders import Scorer\nfrom swig_decoders import BeamDecoder, ctc_beam_search_decoder_ex\n\nfrom jetson_voice.utils import global_config\n\n\nclass CTCBeamSearchDecoder(CTCDecoder):\n    \"\"\"\n    CTC beam search decoder that optionally uses a language model.\n    \"\"\"\n    def __init__(self, config, vocab, resource_path=None):\n        \"\"\"\n        Create a new CTCBeamSearchDecoder.\n        \n        See CTCDecoder.from_config() to automatically create\n        the correct type of instance dependening on config.\n        \"\"\"\n        super().__init__(config, vocab)\n        self.config.setdefault('word_threshold', -1000.0)\n        self.reset()\n        \n        self.scorer = None    \n        #self.num_cores = max(os.cpu_count(), 1)\n        \n        # set default config\n        # https://github.com/NVIDIA/NeMo/blob/855ce265b80c0dc40f4f06ece76d2c9d6ca1be8d/nemo/collections/asr/modules/beam_search_decoder.py#L21\n        self.config.setdefault('language_model', None)\n        self.config.setdefault('beam_width', 32)#128)\n        self.config.setdefault('alpha', 0.7 if self.language_model else 0.0)\n        self.config.setdefault('beta', 0.0)\n        self.config.setdefault('cutoff_prob', 1.0)\n        self.config.setdefault('cutoff_top_n', 40)\n        self.config.setdefault('top_k', 3)\n        \n        # check for language model file\n        if self.language_model:\n            if not os.path.isfile(self.language_model):\n                self.config['language_model'] = os.path.join(resource_path, self.language_model)\n                if not os.path.isfile(self.language_model):\n                    raise IOError(f\"language model file '{self.language_model}' does not exist\")\n                    \n   
     logging.info('creating CTCBeamSearchDecoder')\n        logging.info(str(self.config))\n        \n        # create scorer\n        if self.language_model:\n            self.scorer = Scorer(self.config['alpha'],\n                                 self.config['beta'],\n                                 model_path=self.language_model,\n                                 vocabulary=self.vocab)\n            \n    def decode(self, logits):\n        \"\"\"\n        Decode logits into words, and merge the new words with the\n        previous words from the running transcript.\n        \n        Returns the running transcript as a list of word dictionaries, \n        where each word dict has he following keys:\n        \n           'text' (str) -- the text of the word\n           'score' (float) -- the probability of the word\n           'start_time' (int) -- the start time of the word (in timesteps)\n           'end_time' (int) -- the end time of the word (in timesteps)\n           \n        Note that the start/end times are transformed from timestamps into\n        seconds by the ASR engine after CTCDecoder.decode() is called.\n        \"\"\"\n        results = ctc_beam_search_decoder_ex(\n            logits.tolist(), \n            self.vocab,\n            self.config['beam_width'], \n            self.config['cutoff_prob'], \n            self.config['cutoff_top_n'], \n            self.config['top_k'],\n            self.timestep,\n            self.scorer)\n        \n        \n        if global_config.debug:\n            print('BeamSearch results', len(results))\n            for idx, result in enumerate(results):\n                print(f\"  beam {idx} [{result.score:.3f}] '{result.text}'\")\n                for word_idx, word in enumerate(result.words):\n                    print(f\"    word {word_idx} [{word.start_time}:{word.end_time} {word.score:.3f}] '{word.text}'\")\n                \n        words = [{\n            'text' : word.text,\n            'score' : 
word.score,\n            'start_time' : word.start_time,\n            'end_time' : word.end_time\n        } for word in results[0].words]\n        \n        # merge new words with past words\n        self.words = merge_words(self.words, words, self.config['word_threshold'], 'similarity')\n        \n        # look for silent/EOS intervals\n        silent_intervals = find_silent_intervals(logits, len(self.vocab), self.timesteps_silence, self.timestep) \n        \n        if global_config.debug: \n            print(f'silent intervals:  {silent_intervals}')\n\n        self.timestep += self.timestep_delta\n        \n        # split the words at EOS intervals\n        if len(silent_intervals) > 0:\n            wordlists = split_words(self.words, silent_intervals)\n            transcripts = []\n            \n            for idx, wordlist in enumerate(wordlists):\n                # ignore blanks (silence after EOS has already occurred)\n                if len(wordlist) == 0:\n                    continue\n                    \n                # if there is only one wordlist, then it must be EOS\n                # if there are multiple, then the last one is not EOS\n                end = (len(wordlists) == 1) or (idx < (len(wordlists) - 1))\n                \n                if end:\n                    wordlist = rebase_word_times(wordlist)\n                    self.reset()            # TODO reset timesteps counter correctly\n                else:\n                    self.words = wordlist   \n                    \n                transcripts.append((wordlist, end))\n        else:\n            transcripts = [(self.words, False)]\n\n        return [{\n            'text' : transcript_from_words(words, scores=global_config.debug, times=global_config.debug, end=end, add_punctuation=self.config['add_punctuation']),\n            'words' : words,\n            'end' : end\n        } for words, end in transcripts]\n        \n    def reset(self):\n        \"\"\"\n        Reset the 
CTC decoder state at EOS (end of sentence)\n        \"\"\"\n        #self.timestep = 0\n        #self.tail_silence = 0\n        self.words = []\n        \n    @property\n    def language_model(self):\n        return self.config['language_model']\n "
  },
  {
    "path": "jetson_voice/models/asr/ctc_decoder.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\n        \nclass CTCDecoder:\n    \"\"\"\n    CTC decoder base class for ASR.\n    \"\"\"    \n    @staticmethod\n    def from_config(config, vocab, resource_path=None):\n        \"\"\"\n        Static factory function to instantiate the correct\n        CTC decoder instance type from the config.\n        \n           config['type'] == 'greedy' -> CTCGreedyDecoder\n           config['type'] == 'beamsearch' -> CTCBeamSearchDecoder\n        \"\"\"\n        type = config['type'].lower()\n        \n        if type == 'greedy':\n            from .ctc_greedy import CTCGreedyDecoder\n            return CTCGreedyDecoder(config, vocab)\n        elif type == \"beamsearch\":\n            from .ctc_beamsearch import CTCBeamSearchDecoder\n            return CTCBeamSearchDecoder(config, vocab, resource_path)\n        else:\n            raise ValueError(f\"invalid/unrecognized CTC decoder type '{type}'\")\n            \n    def __init__(self, config, vocab):\n        \"\"\"\n        See CTCDecoder.from_config() to automatically create\n        the correct type of instance dependening on config.\n        \"\"\"\n        self.config = config\n        self.vocab = vocab\n        self.timestep = 0\n        \n        self.config.setdefault('vad_eos_duration', 0.65)  # max silent time until end-of-sentence\n        self.config.setdefault('timestep_offset', 5)      # number of symbols to drop for smooth streaming\n        \n    def decode(self, logits):\n        \"\"\"\n        Decode logits into words, and merge the new words with the\n        previous words from the running transcript.\n        \n        Returns the running transcript as a list of word dictionaries, \n        where each word dict has he following keys:\n        \n           'text' (str) -- the text of the word\n           'score' (float) -- the probability of the word\n           'start_time' (int) -- the start time of the word (in timesteps)\n           
'end_time' (int) -- the end time of the word (in timesteps)\n           \n        Note that the start/end times are transformed from timestamps into\n        seconds by the ASR engine after CTCDecoder.decode() is called.\n        \"\"\"\n        pass\n        \n    def reset(self):\n        \"\"\"\n        Reset the CTC decoder state at EOS (end of sentence)\n        \"\"\"\n        pass\n\n    def set_timestep(self, timestep):\n        \"\"\"\n        Set the current timestep.\n        \"\"\"\n        self.timestep = timestep\n    \n    def set_timestep_delta(self, offset):\n        \"\"\"\n        Set the number of timesteps per frame.\n        \"\"\"\n        self.timestep_delta = offset - self.config['timestep_offset']\n        \n    def set_timestep_duration(self, duration):\n        \"\"\"\n        Set the duration of each timestep, in seconds.\n        \"\"\"\n        self.timestep_duration = duration\n        self.timesteps_silence = self.config['vad_eos_duration'] / self.timestep_duration\n             \n    @property\n    def type(self):\n        \"\"\"\n        Return the CTC decoder type string ('greedy' or 'beamsearch')\n        \"\"\"\n        return self.config['type'].lower() \n        \n "
  },
  {
    "path": "jetson_voice/models/asr/ctc_greedy.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport string\nimport numpy as np\n\nfrom .ctc_decoder import CTCDecoder\nfrom .ctc_utils import merge_words, transcript_from_words\n\nfrom jetson_voice.utils import global_config\n\n\nclass CTCGreedyDecoder(CTCDecoder):\n    \"\"\"\n    CTC greedy decoder that simply chooses the highest-probability logits.\n    \"\"\"\n    def __init__(self, config, vocab):\n        \"\"\"\n        Create a new CTCGreedyDecoder.\n        TODO document config.\n        \n        See CTCDecoder.from_config() to automatically create\n        the correct type of instance dependening on config.\n        \"\"\"\n        super().__init__(config, vocab)\n        \n        self.config.setdefault('word_threshold', 0.1)\n        \n        # add blank symbol to vocabulary\n        if '_' not in vocab:\n            self.vocab = vocab.copy()\n            self.vocab.append('_')\n            \n        self.reset()\n        \n    def decode(self, logits):\n        \"\"\"\n        Decode logits into words, and merge the new words with the\n        previous words from the running transcript.\n        \n        Returns the running transcript as a list of word dictionaries, \n        where each word dict has he following keys:\n        \n           'text' (str) -- the text of the word\n           'score' (float) -- the probability of the word\n           'start_time' (int) -- the start time of the word (in timesteps)\n           'end_time' (int) -- the end time of the word (in timesteps)\n           \n        Note that the start/end times are transformed from timestamps into\n        seconds by the ASR engine after CTCDecoder.decode() is called.\n        \"\"\"\n        text = []\n        prob = 1.0\n        probs = []\n        \n        # select the chars with the max probability\n        for i in range(logits.shape[0]):\n            argmax = np.argmax(logits[i])\n            text.append(self.vocab[argmax])\n            
probs.append(logits[i][argmax])\n              \n        if global_config.debug:\n            print(text)\n            \n        # get the max number of sequential silent timesteps (continuing from last frame)\n        silent_timesteps = self.end_silent_timesteps\n        max_silent_timesteps = 0\n        \n        for i in range(len(text)):\n            if text[i] == '_':\n                silent_timesteps += 1\n            else:\n                max_silent_timesteps = max(silent_timesteps, max_silent_timesteps) if i > 0 else 0\n                silent_timesteps = 0\n        \n        # carry the trailing run of silent timesteps over to the next frame\n        # (0 when this frame ends on a non-blank symbol, so no stale count survives)\n        self.end_silent_timesteps = silent_timesteps\n           \n        # merge repeating chars and blank symbols\n        _, words = self.merge_chars(text, probs)  #text[:len(text)-self.config['offset']]\n        \n        # merge new words with past words\n        words = merge_words(self.words, words, self.config['word_threshold'], 'overlap')\n        \n        # increment timestep (after this frame's timestep is done being used, and before a potential EOS reset)\n        self.timestep += self.timestep_delta\n        \n        # check for EOS\n        end = False\n        \n        if silent_timesteps > self.timesteps_silence:\n            end = True\n            self.reset()\n        else:\n            self.words = words\n            \n        return [{\n            'text' : transcript_from_words(words, scores=global_config.debug, times=global_config.debug, end=end, add_punctuation=self.config['add_punctuation']),\n            'words' : words,\n            'end' : end\n        }]\n           \n    def merge_chars(self, text, probs):\n        \"\"\"\n        Merge repeating chars and blank symbols into words.\n        \"\"\"\n        text_merged = ''\n        \n        word = None\n        words = []\n\n        def ispunct(ch):\n            return ch in (string.punctuation + ' ')\n            \n        for i in range(len(text)):\n            if text[i] != 
self.prev_char and text[i] != '_':\n                if text[i] != '_':\n                    text_merged += text[i]\n\n                    if not ispunct(text[i]):\n                        if word is None:\n                            word = {\n                                'text' : text[i],\n                                'score' : probs[i],\n                                'start_char' : len(text_merged) - 1,\n                                'end_char' : len(text_merged),\n                                'start_time' : self.timestep + i,\n                                'end_time' : self.timestep + i + 1\n                            }\n                        else:\n                            word['text'] += text[i]\n                            word['score'] *= probs[i]\n                            word['end_char'] = len(text_merged)\n                            word['end_time'] = self.timestep + i + 1\n    \n                if ispunct(text[i]) and word is not None:\n                    words.append(word)\n                    word = None\n\n            # track every symbol, blanks included: in CTC a blank between two identical\n            # characters means the character is genuinely repeated and must not be collapsed\n            self.prev_char = text[i]\n            \n        if word is not None:\n            words.append(word)\n                \n        return text_merged, words\n        \n    def reset(self):\n        \"\"\"\n        Reset the CTC decoder state at EOS (end of sentence)\n        \"\"\"\n        self.prev_char = ''\n        self.end_silent_timesteps = 0\n        self.timestep = 0\n        self.words = []\n\n "
  },
  {
    "path": "jetson_voice/models/asr/ctc_utils.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport nltk\nimport numpy as np\n\nfrom jetson_voice.utils import global_config\n\n\ndef transcript_from_words(words, scores=False, times=False, end=False, add_punctuation=True):\n    \"\"\"\n    Convert a list of words to the text transcript.\n    \"\"\"\n    transcript = ''\n    \n    for idx, word in enumerate(words):\n    \n        if scores and times:\n            transcript += f\"{word['text']} ({word['start_time']}:{word['end_time']} {word['score']:.2f})\"\n        elif scores:\n            transcript += f\"{word['text']} ({word['score']:.2f})\"\n        elif times:\n            transcript += f\"{word['text']} ({word['start_time']}:{word['end_time']})\"\n        else:\n            transcript += word['text']\n        \n        if idx < len(words) - 1:\n            transcript += ' '\n      \n    if end and add_punctuation:\n        transcript += '.'  # add punctuation to end\n      \n    return transcript\n        \n\ndef find_overlapping_word(wordlist, word):\n    \"\"\"\n    Find the first word from the list with overlapping times.\n    Returns a (word, index) tuple or (None, -1) if no overlap found.\n    \"\"\"\n    for idx, word2 in enumerate(wordlist):\n        if not (word['end_time'] < word2['start_time'] or word['start_time'] > word2['end_time']):\n            return word2, idx \n    return None, -1\n\n\ndef find_word_after(wordlist, time):\n    \"\"\"\n    Find the nearest word that starts after the time.\n    Returns a (word, index) tuple or (None, 1) if all words start before the time.\n    \"\"\"\n    if isinstance(time, tuple):\n        time = time[1]  # use the end time\n        \n    for idx, word in enumerate(wordlist):\n        if time <= word['start_time']:\n            return word, idx        \n            \n    return None, -1\n\n\ndef find_word_before(wordlist, time):\n    \"\"\"\n    Find the nearest word that starts after the time.\n    Returns a (word, index) tuple or (None, 1) 
if all words start after the time.\n    \"\"\"\n    if isinstance(time, tuple):\n        time = time[0]  # use the start time\n        \n    for idx in range(len(wordlist)-1, -1, -1):\n        if time >= wordlist[idx]['end_time']:\n            return wordlist[idx], idx    \n            \n    return None, -1\n\n\ndef merge_words(wordlist, words, score_threshold=-np.inf, method='overlap'):\n    \"\"\"\n    Merge new words with past words.  This works by finding overlapping or similar words,\n    and replacing the old word with new word if the new word has a higher probability.\n    \"\"\"\n    if len(words) == 0:\n        return wordlist\n        \n    if len(wordlist) == 0:\n        return words\n        \n    # short-circuit if these are all new words    \n    if words[0]['start_time'] > wordlist[-1]['end_time']:\n        wordlist.extend(words)\n        return wordlist\n         \n    if method == 'overlap':\n        # find words that overlap and pick the highest-scoring one\n        for word in words:\n            if word['score'] < score_threshold: #self.config['word_threshold']:\n                continue\n                \n            if len(wordlist) == 0 or word['start_time'] > wordlist[-1]['end_time']:\n                wordlist.append(word)\n                continue\n\n            overlap_word, overlap_idx = find_overlapping_word(wordlist, word)\n            \n            if overlap_word is None:\n                continue\n\n            if global_config.debug:\n                print(f\"found new '{word['text']}' ({word['start_time']}:{word['end_time']} {word['score']:.2f}) overlaps with '{overlap_word['text']}' ({overlap_word['start_time']}:{overlap_word['end_time']} {overlap_word['score']:.2f})\")\n\n            if word['score'] > overlap_word['score']:\n                wordlist[overlap_idx] = word\n                \n    elif method == 'similarity':\n        # find the most-similar past word to the first new word\n        similarity_metric = np.inf #1000\n   
     similarity_index = -1\n        \n        for idx in range(len(wordlist)-1, -1, -1):  # search in reverse so words early in the transcript aren't matched first\n            similarity = nltk.edit_distance(words[0]['text'], wordlist[idx]['text'])\n            \n            if similarity < similarity_metric:\n                similarity_metric = similarity\n                similarity_index = idx\n                \n            if similarity == 0:\n                break\n           \n        if global_config.debug:\n            print(f\"closest word to '{words[0]['text']}' is '{wordlist[similarity_index]['text']}' (similarity={similarity_metric}) \")\n        \n        wordlist = wordlist[:similarity_index]\n        wordlist.extend(words)\n        \n    else:\n        raise ValueError(f\"invalid method '{method}' (valid options are 'overlap', 'similarity')\")\n        \n    return wordlist\n        \n        \ndef split_words(wordlist, times):\n    \"\"\"\n    Split the word list by the given times.\n    note - these times should be sorted\n    \"\"\"\n    wordlists = []\n\n    for time in times:\n        _, idx = find_word_after(wordlist, time)\n        \n        if idx < 0:\n            wordlists.append(wordlist)\n            return wordlists\n            \n        wordlists.append(wordlist[:idx])\n        wordlist = wordlist[idx:]\n        \n    wordlists.append(wordlist)    \n    return wordlists\n        \n        \ndef rebase_word_times(wordlist):\n    \"\"\"\n    Re-base the word timings so that the start of the first word is zero.\n    \"\"\"\n    if len(wordlist) == 0:\n        return wordlist\n        \n    #wordlist = wordlist.copy()\n    start_offset = wordlist[0]['start_time']\n            \n    for idx in range(len(wordlist)):\n        wordlist[idx]['start_time'] -= start_offset\n        wordlist[idx]['end_time'] -= start_offset\n    \n    return wordlist\n\n\ndef find_silent_intervals(logits, blank_symbol_id, min_silent_time, time_offset):\n    
\"\"\"\n    Find blank/silent regions in the output logits.\n    \"\"\"\n    num_timesteps = logits.shape[0]\n    silent_intervals = []\n    last_interval_start = None\n    \n    for i in range(num_timesteps):\n        argmax = np.argmax(logits[i])\n        \n        if argmax == blank_symbol_id:\n            if last_interval_start is None:\n                last_interval_start = i \n        \n        if last_interval_start is not None and (argmax != blank_symbol_id or (i == num_timesteps-1)):\n            if i - last_interval_start >= min_silent_time:\n                silent_intervals.append((last_interval_start + time_offset, i-1+time_offset))\n            #    print(f'     new silent interval ({last_interval_start + self.timestep}:{i-1+self.timestep}) {i - last_interval_start} > {min_length:.2f}')  \n            #else:\n            #    print(f'skipping silent interval ({last_interval_start + self.timestep}:{i-1+self.timestep}) {i - last_interval_start} < {min_length:.2f}')\n                \n            last_interval_start = None\n\n    return silent_intervals\n        \n"
  },
  {
    "path": "jetson_voice/models/nlp/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .intent_slot import IntentSlotEngine\r\nfrom .question_answer import QuestionAnswerEngine\r\nfrom .text_classification import TextClassificationEngine\r\nfrom .token_classification import TokenClassificationEngine"
  },
  {
    "path": "jetson_voice/models/nlp/intent_slot.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\nimport numpy as np\n\nfrom transformers import AutoTokenizer\n\nfrom jetson_voice.nlp import IntentSlotService\nfrom jetson_voice.utils import load_model, normalize_logits\nfrom .nlp_utils import find_subtokens, nlp_dynamic_shapes\n\n\nclass IntentSlotEngine(IntentSlotService):\n    \"\"\"\n    Joint Intent and Slot classification model in TensorRT / onnxruntime.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Load an Intent/Slot classification model from ONNX\n        \"\"\"\n        super(IntentSlotEngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'intent_slot':\n            raise ValueError(f\"{self.config.model_path} isn't an Intent/Slot model (type '{self.config.type}'\")\n            \n        # load model\n        dynamic_shapes = {'max' : (1, self.config['language_model']['max_seq_length'])}  # (batch_size, sequence_length)\n        \n        if nlp_dynamic_shapes:\n            dynamic_shapes['min'] = (1, 1)\n        \n        self.model = load_model(self.config, dynamic_shapes)\n        \n        # create tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(self.config['tokenizer']['tokenizer_name'])\n        self.null_slot = self.slot_labels[-1]  # 'O' in assistant dataset - always the last label?\n        \n        \n    def __call__(self, query):\n        \"\"\"\n        Perform intent/slot classification on the input query.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             'What is the weather in San Francisco tomorrow?'\n\n        Returns a dict with the following keys:\n             'intent' (string) -- the classified intent label\n             'score' (float) -- the intent probability [0,1]\n             'slots' (list[dict]) -- a list of dicts, where each dict has the following keys:\n                  'slot' 
(string) -- the slot label\n                  'text' (string) -- the slot text from the query\n                  'score' (float) -- the slot probability [0,1]\n        \"\"\"\n        encodings = self.tokenizer(\n            text=query,\n            padding='longest' if nlp_dynamic_shapes else 'max_length',\n            truncation=True,\n            max_length=self.config['language_model']['max_seq_length'],\n            return_tensors='np',\n            return_token_type_ids=True,\n            return_overflowing_tokens=True,\n            return_offsets_mapping=True,\n            return_special_tokens_mask=True,\n        )\n\n        # during slot classification, we want to ignore slots from subtokens and special tokens \n        subtoken_mask = find_subtokens(encodings, method='subtoken_delimiters')\n        ignore_mask = subtoken_mask | encodings['special_tokens_mask']\n    \n        # retrieve the inputs from the encoded tokens\n        inputs = {}\n        \n        for input in self.model.inputs:\n            if input.name not in encodings:\n                raise ValueError(f\"the encoded inputs from the tokenizer doesn't contain '{input.name}'\")\n\n            inputs[input.name] = encodings[input.name]\n                    \n        # run the model\n        intent_logits, slot_logits = self.model.execute(inputs)\n\n        intent_logits = normalize_logits(intent_logits)\n        slot_logits = normalize_logits(slot_logits)\n\n        intent_preds = np.argmax(intent_logits, axis=-1)\n        slot_preds = np.argmax(slot_logits, axis=-1)\n\n        # convert numerical outputs to intent/slot labels\n        results = []\n\n        for query_idx, intent_id in enumerate(intent_preds):\n            results.append({\n                'intent' : self.intent_label(intent_id),\n                'score' : intent_logits[query_idx][intent_id],\n                'slots' : []\n            })\n                \n        for query_idx, slots in enumerate(slot_preds):\n            
query_slots = [self.slot_label(slot) for slot in slots]\n\n            for token_idx, slot in enumerate(query_slots):\n                # ignore unclassified slots or masked tokens\n                if slot == self.null_slot or ignore_mask[query_idx][token_idx]:\n                    continue\n                    \n                # convert from token index back to the query string\n                chars = encodings.token_to_chars(query_idx, token_idx)\n                text = query[chars[0]:chars[1]]      # queries[query_idx]\n                \n                # append subtokens from the query to the text\n                for subtoken_idx in range(token_idx+1, len(query_slots)):\n                    if subtoken_mask[query_idx][subtoken_idx]:\n                        subtoken_chars = encodings.token_to_chars(query_idx, subtoken_idx)\n                        text += query[subtoken_chars[0]:subtoken_chars[1]]\n                    else:\n                        break\n                        \n                results[query_idx]['slots'].append({\n                    'slot' : slot,\n                    'text' : text,\n                    'score' : slot_logits[query_idx][token_idx][slots[token_idx]]\n                })\n        \n        if len(results) == 1:\n            return results[0]\n        else:\n            return results\n            \n    @property\n    def intent_labels(self):\n        \"\"\"\n        List of the intent class labels.\n        \"\"\"\n        return self.config['data_desc']['intent_labels']\n    \n    def intent_label(self, index):\n        \"\"\"\n        Return an intent label by index (with bounds checking)\n        \"\"\"\n        return self.intent_labels[int(index)] if index < len(self.intent_labels) else 'Unknown_Intent'\n        \n    @property\n    def slot_labels(self):\n        \"\"\"\n        List of the slot class labels.\n        \"\"\"\n        return self.config['data_desc']['slot_labels']\n    \n    def slot_label(self, 
index):\n        \"\"\"\n        Return a slot label by index (with bounds checking)\n        \"\"\"\n        return self.slot_labels[int(index)] if index < len(self.slot_labels) else self.null_slot\n        "
  },
  {
    "path": "jetson_voice/models/nlp/nlp_utils.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport numpy as np\n\n\n# NLP BERT models (and BERT derivatives) have myelin problem with dynamic shapes on aarch64,\n# so we disable dynamic shape changing for now (shapes will be set to the max sequence length)\nnlp_dynamic_shapes=False\n\n\ndef find_subtokens(encodings, method='char_span'):\n    \"\"\"\n    Compute the subtoken mask, where each token is marked as True if it's a subtoken or False otherwise.\n    Longer words/acronyms may be tokenized into mulitple word pieces (called subtokens), for example:\n    \n        'Yosemite' -> ['yo', '##se', '##mite']\n        'U.S.' -> ['u', '.', 's', '.']\n    \n    Parameters:\n      encodings (BatchEncoding) -- Output from tokenizer\n      \n      method (string) -- If 'char_span', the subtoken mask will be determined by looking at the character\n                         indices.  Tokens that map to characters that are side-by-side are flagged as subtokens.\n                         \n                         If 'subtoken_delimiters', subtokens will be identified by looking for '##' symbols.\n                         However this can miss punctuated subtokens, such as 'U.S.'\n    \n    Returns boolean subtoken mask array with shape (num_queries, num_tokens)\n    \"\"\"\n    num_queries = encodings['input_ids'].shape[0]\n    subtoken_mask = []\n    \n    if method == 'char_span':\n        for query_idx in range(num_queries):\n            mask = []\n            last_char = -1\n            tokens = encodings.tokens(query_idx)\n            \n            for token_idx, word_id in enumerate(encodings.word_ids(query_idx)):\n                if word_id is None:  # skip special tokens\n                    mask.append(False)\n                    continue\n                    \n                chars = encodings.token_to_chars(query_idx, token_idx)\n                \n                if chars[0] == last_char:\n                    mask.append(True)\n                else:\n 
                   mask.append(False)\n                    \n                last_char = chars[1]\n\n            subtoken_mask.append(mask)\n            \n    elif method == 'subtoken_delimiters':\n        for query_idx in range(num_queries):\n            subtoken_mask.append([token.startswith('##') for token in encodings.tokens(query_idx)])\n    else:\n        raise ValueError(f\"invalid method ('{method}')\")\n        \n    return np.asarray(subtoken_mask)\n        "
  },
  {
    "path": "jetson_voice/models/nlp/question_answer.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\nimport numpy as np\n\nfrom transformers import AutoTokenizer\n\nfrom jetson_voice.nlp import QuestionAnswerService\nfrom jetson_voice.utils import load_model, normalize_logits\nfrom .nlp_utils import nlp_dynamic_shapes\n\n\nclass QuestionAnswerEngine(QuestionAnswerService):\n    \"\"\"\n    Question answering model in TensorRT / onnxruntime.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Load an question answering model from ONNX\n        \"\"\"\n        super(QuestionAnswerEngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'qa':\n            raise ValueError(f\"{self.config.model_path} isn't a Question Answering model (type '{self.config.type}'\")\n            \n        # load model\n        dynamic_shapes = {'max' : (1, self.config['dataset']['max_seq_length'])}  # (batch_size, sequence_length)\n        \n        if nlp_dynamic_shapes:\n            dynamic_shapes['min'] = (1, 1)\n        \n        self.model = load_model(self.config, dynamic_shapes)\n        \n        # create tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(self.config['tokenizer']['tokenizer_name'])\n        self.question_first = bool(self.tokenizer.padding_side == \"right\")\n        \n        \n    def __call__(self, query, top_k=1):\n        \"\"\"\n        Perform question/answering on the input query.\n        \n        Parameters:\n          query (dict or tuple) -- Either a dict with 'question' and 'context' keys,\n                                   or a (question, context) tuple.\n          top_k (int) -- How many of the top results to return, sorted by score.\n                         The default (top_k=1) is to return just the top result.\n                         If top_k > 1, then a list of results will be returned.\n          \n        Returns:\n          dict(s) with the following keys:\n          \n           
  'answer' (string) -- the answer text\n             'score' (float) -- the probability [0,1]\n             'start' (int) -- the starting character index of the answer into the context text\n             'end' (int) -- the ending character index of the answer into the context text\n             \n          If top_k > 1, a list of dicts with the top_k results will be returned.\n          If top_k == 1, just the single dict with the top score will be returned.\n        \"\"\"\n        if isinstance(query, dict):\n            question = query['question']\n            context = query['context']\n        elif isinstance(query, tuple):\n            question = query[0]\n            context = query[1]\n        else:\n            raise ValueError(f'query must be a dict or tuple (instead was type {type(query).__name__})')\n\n        # check for models that have a doc_stride >= max_seq_length\n        # this will cause an exception in the tokenizer\n        doc_stride = self.config['dataset']['doc_stride']\n        max_seq_len = self.config['dataset']['max_seq_length']\n        \n        if doc_stride >= max_seq_len:\n            doc_stride = int(max_seq_len/2)\n            \n        # tokenize the inputs\n        encodings = self.tokenizer(\n            text=question if self.question_first else context,\n            text_pair=context if self.question_first else question,\n            padding='longest' if nlp_dynamic_shapes else 'max_length',\n            truncation=\"only_second\" if self.question_first else \"only_first\",\n            max_length=max_seq_len,\n            stride=doc_stride,\n            return_tensors='np',\n            return_token_type_ids=True,\n            return_overflowing_tokens=True,\n            return_offsets_mapping=True,\n            return_special_tokens_mask=True,\n        )\n        \n        # When the input is too long, it's converted in a batch of inputs with overflowing tokens\n        # and a stride of overlap between the inputs. 
If a batch of inputs is given, a special output\n        # \"overflow_to_sample_mapping\" indicate which member of the encoded batch belong to which original batch sample.\n        # Here we tokenize examples one-by-one so we don't need to use \"overflow_to_sample_mapping\".\n        # \"num_span\" is the number of output samples generated from the overflowing tokens.\n        num_spans = len(encodings[\"input_ids\"])\n        logging.debug(f'num_spans: {num_spans}')\n\n        # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)\n        # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)\n        p_mask = np.asarray(\n            [\n                [tok != 1 if self.question_first else 0 for tok in encodings.sequence_ids(span_id)]\n                for span_id in range(num_spans)\n            ]\n        )\n\n        # keep the cls_token unmasked (some models use it to indicate unanswerable questions)\n        if self.tokenizer.cls_token_id is not None:\n            cls_index = np.nonzero(encodings[\"input_ids\"] == self.tokenizer.cls_token_id)\n            p_mask[cls_index] = 0\n            \n        # run the model over each span (TODO batching)\n        model_outputs = []\n        \n        for span_idx in range(num_spans):\n            inputs = {}\n            \n            for input in self.model.inputs:\n                if input.name not in encodings:\n                    raise ValueError(f\"the encoded inputs from the tokenizer doesn't contain '{input.name}'\")\n\n                inputs[input.name] = np.expand_dims(encodings[input.name][span_idx], axis=0) # add batch dim\n\n            model_outputs.append(self.model.execute(inputs))\n            \n        # post-processing\n        answers = []\n        min_null_score = 1000000\n        handle_impossible_answer = self.config['dataset']['version_2_with_negative']\n        \n        for span_idx in range(num_spans):\n    
        start_logits = np.squeeze(model_outputs[span_idx][:,:,0])\n            end_logits = np.squeeze(model_outputs[span_idx][:,:,1])\n\n            # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.\n            undesired_tokens = np.abs(p_mask[span_idx] - 1) & encodings['attention_mask'][span_idx]\n\n            # Generate mask\n            undesired_tokens_mask = (undesired_tokens == 0.0)\n\n            # Make sure non-context indexes in the tensor cannot contribute to the softmax\n            start_logits = np.where(undesired_tokens_mask, -10000.0, start_logits)\n            end_logits = np.where(undesired_tokens_mask, -10000.0, end_logits)\n\n            # Normalize logits and spans to retrieve the answer\n            start_logits = np.exp(start_logits - np.log(np.sum(np.exp(start_logits), axis=-1, keepdims=True)))\n            end_logits = np.exp(end_logits - np.log(np.sum(np.exp(end_logits), axis=-1, keepdims=True)))\n\n            if handle_impossible_answer:\n                min_null_score = min(min_null_score, (start_logits[0] * end_logits[0]).item())\n\n            # Mask CLS\n            start_logits[0] = end_logits[0] = 0.0\n\n            # Decode token probabilities\n            starts, ends, scores = self.decode(start_logits, end_logits, top_k=top_k)\n\n            if self.tokenizer.is_fast:\n                # Convert the answer (tokens) back to the original text\n                # Score: score from the model\n                # Start: Index of the first character of the answer in the context string\n                # End: Index of the character following the last character of the answer in the context string\n                # Answer: Plain text of the answer\n                enc = encodings[span_idx]\n                \n                # Sometimes the max probability token is in the middle of a word so:\n                # - we start by finding the right word containing the token with `token_to_word`\n             
   # - then we convert this word in a character span with `word_to_chars`\n                for s, e, score in zip(starts, ends, scores):\n                    start = enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if self.question_first else 0)[0]\n                    end = enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if self.question_first else 0)[1]\n                    \n                    answers.append({\n                        'answer' : context[start : end],\n                        'score' : score.item(),\n                        'start' : start,\n                        'end' : end\n                    })\n            else:\n                raise NotImplementedError('QA post-processing is only implemented for fast tokenizers')\n            \n        if handle_impossible_answer:\n            answers.append({'answer': '', 'score': min_null_score, 'start': 0, 'end': 0})\n\n        answers = sorted(answers, key=lambda x: x['score'], reverse=True)[:top_k]\n        \n        if top_k == 1:\n            return answers[0]\n        else:\n            return answers\n\n\n    def decode(self, start: np.ndarray, end: np.ndarray, top_k: int):\n        \"\"\"\n        Take the QA model output and will generate probabilities for each span to be the actual answer.\n        In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or\n        answer end position being before the starting position. 
The method supports output the k-best answer through\n        the top_k argument.\n        Args:\n            start (:obj:`np.ndarray`): Individual start probabilities for each token.\n            end (:obj:`np.ndarray`): Individual end probabilities for each token.\n            top_k (:obj:`int`): Indicates how many possible answer span(s) to extract from the model output.\n            max_answer_len (:obj:`int`): Maximum size of the answer to extract from the model's output.\n        \"\"\"\n        # Ensure we have batch axis\n        if start.ndim == 1:\n            start = start[None]\n\n        if end.ndim == 1:\n            end = end[None]\n\n        # Compute the score of each tuple(start, end) to be the real answer\n        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))\n\n        # Remove candidate with end < start and end - start > max_answer_len\n        candidates = np.tril(np.triu(outer), self.config['dataset']['max_answer_length'] - 1)\n\n        #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)\n        scores_flat = candidates.flatten()\n        if top_k == 1:\n            idx_sort = [np.argmax(scores_flat)]\n        elif len(scores_flat) < top_k:\n            idx_sort = np.argsort(-scores_flat)\n        else:\n            idx = np.argpartition(-scores_flat, top_k)[0:top_k]\n            idx_sort = idx[np.argsort(-scores_flat[idx])]\n\n        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]\n        return start, end, candidates[0, start, end] \n        "
  },
  {
    "path": "jetson_voice/models/nlp/text_classification.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\nimport numpy as np\n\nfrom transformers import AutoTokenizer\n\nfrom jetson_voice.nlp import TextClassificationService\nfrom jetson_voice.utils import load_model, normalize_logits\nfrom .nlp_utils import nlp_dynamic_shapes\n\n\nclass TextClassificationEngine(TextClassificationService):\n    \"\"\"\n    Text classification model in TensorRT / onnxruntime.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Load a text classification model from ONNX\n        \"\"\"\n        super(TextClassificationEngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'text_classification':\n            raise ValueError(f\"{self.config.model_path} isn't a Text Classification model (type '{self.config.type}'\")\n            \n        # load model\n        dynamic_shapes = {'max' : (1, self.config['dataset']['max_seq_length'])}  # (batch_size, sequence_length)\n        \n        if nlp_dynamic_shapes:\n            dynamic_shapes['min'] = (1, 1)\n        \n        self.model = load_model(self.config, dynamic_shapes)\n        \n        # create tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(self.config['tokenizer']['tokenizer_name'])\n        \n        \n    def __call__(self, query):\n        \"\"\"\n        Perform text classification on the input query.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             'Today was warm, sunny and beautiful out.'\n\n        Returns a dict with the following keys:\n             'class' (int) -- the predicted class index\n             'label' (string) -- the predicted class label (and if there aren't labels `str(class)`)\n             'score' (float) -- the classification probability [0,1]\n        \"\"\"\n        encodings = self.tokenizer(\n            text=query,\n            padding='longest' if nlp_dynamic_shapes else 
'max_length',\n            truncation=True,\n            max_length=self.config['dataset']['max_seq_length'],\n            return_tensors='np',\n            return_token_type_ids=True,\n            return_overflowing_tokens=True,\n            return_offsets_mapping=True,\n            return_special_tokens_mask=True,\n        )\n    \n        # retrieve the inputs from the encoded tokens\n        inputs = {}\n        \n        for input in self.model.inputs:\n            if input.name not in encodings:\n                raise ValueError(f\"the encoded inputs from the tokenizer doesn't contain '{input.name}'\")\n\n            inputs[input.name] = encodings[input.name]\n                    \n        # run the model\n        logits = self.model.execute(inputs)\n        logits = normalize_logits(logits)\n        preds  = np.argmax(logits, axis=-1)\n \n        # tabulate results\n        results = []\n        \n        for query_idx in range(preds.shape[0]):\n            results.append({\n                'class' : int(preds[query_idx]),\n                'label' : str(preds[query_idx]),\n                'score' : logits[query_idx][preds[query_idx]]\n            })\n            \n        if len(results) == 1:\n            return results[0]\n        else:\n            return results\n        "
  },
  {
    "path": "jetson_voice/models/nlp/token_classification.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport logging\nimport numpy as np\n\nfrom transformers import AutoTokenizer\n\nfrom jetson_voice.nlp import TokenClassificationService\nfrom jetson_voice.utils import load_model, normalize_logits\nfrom .nlp_utils import find_subtokens, nlp_dynamic_shapes\n\n\nclass TokenClassificationEngine(TokenClassificationService):\n    \"\"\"\n    Token classification model (aka Named Entity Recognition) in TensorRT / onnxruntime.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Load a token classification model for NER from ONNX\n        \"\"\"\n        super(TokenClassificationEngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'token_classification':\n            raise ValueError(f\"{self.config.model_path} isn't a Token Classification model (type '{self.config.type}'\")\n            \n        # load model\n        dynamic_shapes = {'max' : (1, self.config['dataset']['max_seq_length'])}  # (batch_size, sequence_length)\n        \n        if nlp_dynamic_shapes:\n            dynamic_shapes['min'] = (1, 1)\n        \n        self.model = load_model(self.config, dynamic_shapes)\n        \n        # create tokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(self.config['tokenizer']['tokenizer_name'])\n        \n        \n    def __call__(self, query):\n        \"\"\"\n        Perform token classification (NER) on the input query and return tagged entities.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             'Ben is from Chicago, a city in the state of Illinois, US'\n\n        Returns a list[dict] of tagged entities with the following dictionary keys:\n             'class' (int) -- the entity class index\n             'label' (string) -- the entity class label\n             'score' (float) -- the classification probability [0,1]\n             'text'  (string) -- 
the corresponding text from the input query\n             'start' (int) -- the starting character index of the text\n             'end'   (int) -- the ending character index of the text\n        \"\"\"\n        encodings = self.tokenizer(\n            text=query,\n            padding='longest' if nlp_dynamic_shapes else 'max_length',\n            truncation=True,\n            max_length=self.config['dataset']['max_seq_length'],\n            return_tensors='np',\n            return_token_type_ids=True,\n            return_overflowing_tokens=True,\n            return_offsets_mapping=True,\n            return_special_tokens_mask=True,\n        )\n    \n        # during token classification, we want to ignore slots from subtokens and special tokens \n        subtoken_mask = find_subtokens(encodings)\n        ignore_mask = subtoken_mask | encodings['special_tokens_mask']\n        \n        # retrieve the inputs from the encoded tokens\n        inputs = {}\n        \n        for input in self.model.inputs:\n            if input.name not in encodings:\n                raise ValueError(f\"the encoded inputs from the tokenizer doesn't contain '{input.name}'\")\n\n            inputs[input.name] = encodings[input.name]\n                    \n        # run the model\n        logits = self.model.execute(inputs)\n        logits = normalize_logits(logits)\n        \n        preds = np.argmax(logits, axis=-1)\n        probs = np.amax(logits, axis=-1)\n        \n        # tabulate results\n        tags = []\n        label_map = {v: k for k, v in self.config['label_ids'].items()}\n        num_queries, num_tokens, _ = logits.shape\n        \n        assert num_queries == 1  # there should only be 1 input query currently\n        \n        for query_idx in range(num_queries):\n            query_tags = []\n            \n            for token_idx in range(num_tokens):\n                label = label_map[preds[query_idx][token_idx]]\n                \n                # ignore unclassified 
slots or masked tokens\n                if label == self.config['dataset']['pad_label'] or ignore_mask[query_idx][token_idx]:\n                    continue\n\n                # convert from token index back to the query string\n                chars = encodings.token_to_chars(query_idx, token_idx)\n                \n                # append subtokens from the query to the text\n                for subtoken_idx in range(token_idx+1, num_tokens):\n                    if subtoken_mask[query_idx][subtoken_idx]:\n                        chars = (chars[0], encodings.token_to_chars(query_idx, subtoken_idx)[1])\n                    else:\n                        break\n\n                text = query[chars[0]:chars[1]] # queries[query_idx]\n\n                # strip out punctuation to attach the entity tag to the word not to a punctuation mark\n                if not text[-1].isalpha():\n                    text = text[:-1]\n                    chars = (chars[0], chars[1]-1)\n                        \n                query_tags.append({\n                    'label' : label,\n                    'class' : preds[query_idx][token_idx],\n                    'score' : probs[query_idx][token_idx],\n                    'text' : text,\n                    'start' : chars[0],\n                    'end' : chars[1]\n                })\n                \n            tags.append(query_tags)\n            \n        if len(tags) == 1:\n            return tags[0]\n        else:\n            return tags\n        "
  },
  {
    "path": "jetson_voice/models/tts/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .tts_engine import TTSEngine\r\n"
  },
  {
    "path": "jetson_voice/models/tts/tts_engine.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport re\nimport logging\nimport inflect\n\nimport numpy as np\n\nfrom jetson_voice.tts import TTSService\nfrom jetson_voice.utils import global_config, load_model, softmax\n\n      \nclass TTSEngine(TTSService):\n    \"\"\"\n    Text-to-speech synthesis.  This is actually a pipeline of two models,\n    the generator model (which generates MEL spectrograms from tokens),\n    and the vocoder (which outputs audio from MEL spectrograms)\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Loads the TTS pipeline (text->MEL generator and MEL->audio vocoder) from ONNX or serialized TensorRT engines.\n        \n        Parameters:\n          config -- the model configuration; its type must be 'tts' and it must contain 'generator' and 'vocoder' entries\n        \n        Raises ValueError if the configuration isn't for a Text-to-Speech model.\n        \"\"\"\n        super(TTSEngine, self).__init__(config, *args, **kwargs)\n\n        if self.config.type != 'tts':\n            raise ValueError(f\"{self.config.model_path} isn't a Text-to-Speech model (type '{self.config.type}'\")\n            \n        # load text->MEL generator model\n        self.generator = load_model(self.config.generator)\n        \n        # load MEL->audio vocoder model\n        features = self.config.vocoder.features\n        \n        dynamic_shapes = {\n            'min' : (1, features, 1),\n            'opt' : (1, features, 160), # ~5-6 words\n            'max' : (1, features, 1024) # ~20-30 words?\n        }\n        \n        self.vocoder = load_model(self.config.vocoder, dynamic_shapes=dynamic_shapes)\n        \n        # create map of symbol->ID embeddings\n        self.symbol_to_id = {s: i for i, s in enumerate(self.get_symbols())}\n        \n        # create operators for num-to-word conversion\n        self.number_regex = re.compile(r'\\d+(?:,\\d+)?')  # https://stackoverflow.com/a/16321189\n        self.number_inflect = inflect.engine()\n        \n    
def __call__(self, text):\n        \"\"\"\n        Generate audio from text.\n        \n        Parameters:\n          text (string) -- The phrase to convert to audio.\n\n        Returns audio samples in a numpy array.\n        \"\"\"\n        text = self.numbers_to_words(text)   # vocab doesn't include numbers, so convert them to words\n        \n        pad_symbol = ' '\n        min_length = 6\n        \n        if text[-1].isalnum():      # end with punctuation, otherwise audio is cut-off\n            text += pad_symbol\n          \n        if len(text) < min_length:  # WAR for cuDNN error on JetPack <= 4.5.x\n            text = text.ljust(min_length, pad_symbol)\n            \n        # convert chars to symbol embeddings\n        encoded_text = [self.symbol_to_id[s] for s in text.lower() if s in self.symbol_to_id]\n        encoded_text = np.expand_dims(np.array(encoded_text, dtype=np.int64), axis=0)\n        \n        # generate MEL spectrogram + audio\n        mels = self.generator.execute(encoded_text)[0]\n        audio = self.vocoder.execute(mels)\n\n        return audio.squeeze()\n     \n    def get_symbols(self):\n        \"\"\"\n        Return a list of all the accepted character symbols / embeddings\n        \"\"\"\n        _arpabet = [\n          'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',\n          'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',\n          'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',\n          'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',\n          'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',\n          'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',\n          'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'\n        ]\n        _arpabet = ['@' + s for s in _arpabet]\n        _pad = '_'\n        _punctuation = '!\\'(),.:;? 
'\n        _special = '-'\n        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'\n        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet\n        return symbols\n     \n    def numbers_to_words(self, text):\n        \"\"\"\n        Convert instances of numbers to words in the text.\n        For example:  \"The answer is 42\" -> \"The answer is forty-two\"\n        \"\"\"\n        number_tokens = self.number_regex.findall(text)\n        \n        for number_token in number_tokens:\n            # TODO test/handle floating-point numbers\n            word_text = self.number_inflect.number_to_words(number_token)              \n            num_begin = text.index(number_token)\n\n            # insert the words back at the old location\n            text = text[:num_begin] + word_text + text[num_begin + len(number_token):]\n            \n        return text\n        \n    @property\n    def sample_rate(self):\n        \"\"\"\n        Get the output sample rate in Hz (e.g. 22050, 44100, etc.)\n        \"\"\"\n        return self.config['vocoder']['sample_rate']"
  },
  {
    "path": "jetson_voice/nlp.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice.utils import load_resource\n\n\ndef NLP(resource, *args, **kwargs):\n    \"\"\"\n    Factory for automatically loading NLP models or services.\n    \n    Returns an instance of:\n        - IntentSlotService\n        - QuestionAnswerService\n        - TextClassificationService\n        - TokenClassificationService\n    \"\"\"\n    from jetson_voice.auto import AutoModel\n    return AutoModel(resource, domain='nlp', *args, **kwargs)\n    \n    \ndef IntentSlot(resource, *args, **kwargs):\n    \"\"\"\n    Loads a NLP joint intent/slot classifier service or model.\n    See the IntentSlotService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'tensorrt' : 'jetson_voice.models.nlp.IntentSlotEngine',\n        'onnxruntime' : 'jetson_voice.models.nlp.IntentSlotEngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs)\n\n    \nclass IntentSlotService():\n    \"\"\"\n    Intent/slot classifier service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Create service instance.\n        \"\"\"\n        self.config = config\n        \n    def __call__(self, query):\n        \"\"\"\n        Perform intent/slot classification on the input query.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             'What is the weather in San Francisco tomorrow?'\n\n        Returns a dict with the following keys:\n             'intent' (string) -- the classified intent label\n             'score' (float) -- the intent probability [0,1]\n             'slots' (list[dict]) -- a list of dicts, where each dict has the following keys:\n                  'slot' (string) -- the slot label\n                  'text' (string) -- the slot text from the query\n                  'score' (float) -- the slot probability [0,1]\n        \"\"\"\n        
pass\n\n \ndef QuestionAnswer(resource, *args, **kwargs):\n    \"\"\"\n    Loads a NLP question answering service or model.\n    See the QuestionAnswerService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'tensorrt' : 'jetson_voice.models.nlp.QuestionAnswerEngine',\n        'onnxruntime' : 'jetson_voice.models.nlp.QuestionAnswerEngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs) \n        \n   \nclass QuestionAnswerService():\n    \"\"\"\n    Question answering service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Create service instance.\n        \"\"\"\n        self.config = config\n        \n    def __call__(self, query, top_k=1):\n        \"\"\"\n        Perform question/answering on the input query.\n        \n        Parameters:\n          query (dict or tuple) -- Either a dict with 'question' and 'context' keys,\n                                   or a (question, context) tuple.\n          top_k (int) -- How many of the top results to return, sorted by score.\n                         The default (topk=1) is to return just the top result.\n                         If topk > 1, then a list of results will be returned.\n          \n        Returns:\n          dict(s) with the following keys:\n          \n             'answer' (string) -- the answer text\n             'score' (float) -- the probability [0,1]\n             'start' (int) -- the starting character index of the answer into the context text\n             'end' (int) -- the ending character index of the answer into the context text\n             \n          If top_k > 1, a list of dicts with the topk results will be returned.\n          If top_k == 1, just the single dict with the top score will be returned.\n        \"\"\"\n        pass\n        \n\ndef TextClassification(resource, *args, **kwargs):\n    \"\"\"\n    Loads a NLP text classification service or model.\n  
  See the TextClassificationService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'tensorrt' : 'jetson_voice.models.nlp.TextClassificationEngine',\n        'onnxruntime' : 'jetson_voice.models.nlp.TextClassificationEngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs) \n        \n   \nclass TextClassificationService():\n    \"\"\"\n    Text classification service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Create service instance.\n        \"\"\"\n        self.config = config\n        \n    def __call__(self, query):\n        \"\"\"\n        Perform text classification on the input query.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             'Today was warm, sunny and beautiful out.'\n\n        Returns a dict with the following keys:\n             'class' (int) -- the predicted class index\n             'label' (string) -- the predicted class label (and if there aren't labels `str(class)`)\n             'score' (float) -- the classification probability [0,1]\n        \"\"\"\n        pass\n\n\ndef TokenClassification(resource, *args, **kwargs):\n    \"\"\"\n    Loads a NLP token classification (aka Named Entity Recognition) service or model.\n    See the TokenClassificationService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'tensorrt' : 'jetson_voice.models.nlp.TokenClassificationEngine',\n        'onnxruntime' : 'jetson_voice.models.nlp.TokenClassificationEngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs) \n        \n   \nclass TokenClassificationService():\n    \"\"\"\n    Token classification (aka Named Entity Recognition) service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Create service instance.\n        \"\"\"\n        
self.config = config\n        \n    def __call__(self, query):\n        \"\"\"\n        Perform token classification (NER) on the input query and return tagged entities.\n        \n        Parameters:\n          query (string) -- The text query, for example:\n                             \"Ben is from Chicago, a city in the state of Illinois, US'\n\n        Returns a list[dict] of tagged entities with the following dictionary keys:\n             'class' (int) -- the entity class index\n             'label' (string) -- the entity class label\n             'score' (float) -- the classification probability [0,1]\n             'text'  (string) -- the corresponding text from the input query\n             'start' (int) -- the starting character index of the text\n             'end'   (int) -- the ending character index of the text\n        \"\"\"\n        pass\n\n    @staticmethod\n    def tag_string(query, tags, scores=False):\n        \"\"\"\n        Returns a string with the tags inserted inline with the query.  
For example:\n        \n        \"Ben[B-PER] is from Chicago[B-LOC], a city in the state of Illinois[B-LOC], US[B-LOC]\"\n        \n        Parameters:\n          query  (string) -- The original query string.\n          tags   (list[dict]) -- The tags predicted by the model.\n          scores (bool) -- If true, the probabilities will be added inline.\n                           If false (default), only the tag labels will be added.\n        \"\"\"\n        char_offset = 0\n\n        for tag in tags:\n            if scores:\n                tag_str = f\"[{tag['label']} {tag['score']:.3}]\"\n            else:\n                tag_str = f\"[{tag['label']}]\"\n                \n            query = query[:tag['end'] + char_offset] + tag_str + query[tag['end'] + char_offset:]\n            char_offset += len(tag_str)\n            \n        return query\n        \n        \nif __name__ == \"__main__\":\n\n    from jetson_voice import ConfigArgParser\n    import pprint\n    \n    parser = ConfigArgParser()\n    \n    parser.add_argument('--model', default='distilbert_intent', type=str)\n    parser.add_argument('--type', default='intent_slot', type=str)\n\n    args = parser.parse_args()\n    args.type = args.type.lower()\n    \n    print(args)\n    \n    if args.type == 'intent_slot':\n    \n        model = IntentSlot(args.model)\n        \n        # create some test queries\n        queries = [\n            'Set alarm for Seven Thirty AM',\n            'Please increase the volume',\n            'What is my schedule for tomorrow',\n            'Place an order for a large pepperoni pizza from Dominos'\n        ]\n\n        # process the queries\n        for query in queries:\n            results = model(query)\n            \n            print('\\n')\n            print('query:', query)\n            print('')\n            pprint.pprint(results)\n     \n    elif args.type == 'question_answer' or args.type == 'qa':\n\n        model = QuestionAnswer(args.model)\n        \n        
# create some test queries\n        queries = []\n        \n        queries.append({\n            \"question\" : \"What is the value of Pi?\",\n            \"context\" : \"Some people have said that Pi is tasty but there should be a value for Pi, and the value for Pi is around 3.14. \"\n                        \"Pi is the ratio of a circle's circumference to it's diameter. The constant Pi was first calculated by Archimedes \"\n                        \"in ancient Greece around the year 250 BC.\"\n        })\n        \n        queries.append({\n            \"question\" : \"Who discovered Pi?\",\n            \"context\" : queries[-1]['context']\n        })\n\n        queries.append({\n            \"question\" : \"Which nation contains the majority of the Amazon forest?\",\n            \"context\" : \"The Amazon rainforest is a moist broadleaf forest that covers most of the Amazon basin of South America. \"\n                        \"This basin encompasses 7,000,000 square kilometres (2,700,000 sq mi), of which 5,500,000 square kilometres \"\n                        \"(2,100,000 sq mi) are covered by the rainforest. 
The majority of the forest is contained within Brazil, \"\n                        \"with 60% of the rainforest, followed by Peru with 13%, and Colombia with 10%.\"\n        })\n        \n        queries.append({\n            \"question\" : \"How large is the Amazon rainforest?\",\n            \"context\" : queries[-1]['context']\n        })\n        \n        # process the queries\n        for query in queries:\n            answers = model(query, top_k=5)\n            \n            print('\\n')\n            print('context:', query['context'])\n            print('')\n            print('question:', query['question'])\n            \n            for answer in answers:\n                print('')\n                print('answer:  ', answer['answer'])\n                print('score:   ', answer['score'])\n    \n    elif args.type == 'text_classification':\n    \n        model = TextClassification(args.model)\n        \n        # create some test queries (these are for sentiment models)\n        queries = [\n            \"By the end of no such thing the audience, like beatrice, has a watchful affection for the monster.\",\n            \"Director Rob Marshall went out gunning to make a great one.\",\n            \"Uneasy mishmash of styles and genres.\",\n            \"I love exotic science fiction / fantasy movies but this one was very unpleasant to watch. 
I gave it 4 / 10 since some special effects were nice.\",\n            \"Today was cold and rainy and not very nice.\",\n            \"Today was warm, sunny and beautiful out.\",\n        ]\n\n        # process the queries\n        for query in queries:\n            results = model(query)\n            print('\\nquery:', query)\n            pprint.pprint(results)\n    \n    elif args.type == 'token_classification':\n    \n        model = TokenClassification(args.model)\n    \n        # create some test queries\n        queries = [\n            \"But candidate Charles Baker, who has about eight percent of the vote, has called for an investigation into reports of people voting multiple times.\",\n            \"Analysts say Mr. Chung's comments may be part of efforts by South Korea to encourage North Korea to resume bilateral talks.\",\n            \"The 63-year-old Daltrey walked offstage during the first song; guitarist Pete Townshend later told the crowd he was suffering from bronchitis and could barely speak.\",\n            \"The Who is currently touring in support of Endless Wire, its first album since 1982.\",\n            \"Meanwhile, Iowa is cleaning up after widespread flooding inundated homes, destroyed crops and cut off highways and bridges.\",\n            \"At the White House Tuesday, U.S. President George Bush expressed concern for the flood victims.\",\n            \"Ben is from Chicago, a city in the state of Illinois, US with a population of 2.7 million people.\",\n            \"Lisa's favorite place to climb in the summer is El Capitan in Yosemite National Park in California, U.S.\"\n        ]\n\n        # process the queries\n        for query in queries:\n            tags = model(query)\n            #print(f'\\n{query}')\n            #pprint.pprint(tags)\n            print(f'\\n{model.tag_string(query, tags, scores=True)}')\n        \n    else: \n        raise ValueError(f\"invalid --type argument ({args.type})\")\n        "
  },
  {
    "path": "jetson_voice/tts.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice.utils import load_resource\n\n\ndef TTS(resource, *args, **kwargs):\n    \"\"\"\n    Loads a TTS service or model.\n    See the TTSService class for the signature that implementations use.\n    \"\"\"\n    factory_map = {\n        'riva' : 'jetson_voice.backends.riva.RivaTTSService',\n        'tensorrt' : 'jetson_voice.models.tts.TTSEngine',\n        'onnxruntime' : 'jetson_voice.models.tts.TTSEngine'\n    }\n    \n    return load_resource(resource, factory_map, *args, **kwargs)\n\n    \nclass TTSService():\n    \"\"\"\n    TTS service base class.\n    \"\"\"\n    def __init__(self, config, *args, **kwargs):\n        \"\"\"\n        Create service instance.\n        \"\"\"\n        self.config = config\n        \n    def __call__(self, text):\n        \"\"\"\n        Generate audio from text.\n        \n        Parameters:\n          text (string) -- The phrase to convert to audio.\n\n        Returns audio samples in a numpy array.\n        \"\"\"\n        pass\n    \n    @property\n    def sample_rate(self):\n        \"\"\"\n        Get the output sample rate (in Hz)\n        \"\"\"\n        pass\n        \n        \nif __name__ == \"__main__\":\n\n    from jetson_voice import list_audio_devices, ConfigArgParser\n    from soundfile import SoundFile\n    \n    import pprint\n    import pyaudio\n    import time\n    \n    parser = ConfigArgParser()\n    \n    parser.add_argument('--model', default='fastpitch_hifigan', type=str)\n    parser.add_argument('--text', default='Hello, how are you today?', type=str)\n    parser.add_argument('--warmup', type=int, default=9, help='the number of warmup runs')\n    parser.add_argument(\"--output-device\", type=int, default=None, help='output audio device to use')\n    parser.add_argument(\"--output-wav\", type=str, default=None, help='output wav file to write to')\n    parser.add_argument('--list-devices', action='store_true', help='list audio input 
devices')\n    \n    args = parser.parse_args()\n    print(args)\n    \n    # list audio devices\n    if args.list_devices:\n        list_audio_devices()\n        \n    # load the model\n    tts = TTS(args.model)\n    \n     # display the text\n    print(f\"\\n'{args.text}'\\n\")\n    \n    # run the TTS\n    for run in range(args.warmup+1):\n        start = time.perf_counter()\n        audio = tts(args.text)\n        stop = time.perf_counter()\n        latency = stop-start\n        duration = audio.shape[0]/tts.sample_rate\n        print(f\"Run {run} -- Time to first audio: {latency:.3f}s. Generated {duration:.2f}s of audio. RTFx={duration/latency:.2f}.\")\n        \n    # output the audio\n    if args.output_device is not None:\n        p = pyaudio.PyAudio()\n        stream = p.open(output_device_index=args.output_device, \n                        format=pyaudio.paFloat32, \n                        channels=1, rate=tts.sample_rate, output=True)\n        stream.write(audio.tobytes())\n        stream.close_stream()\n        stream.close()\n        \n    if args.output_wav is not None:\n        wav = SoundFile(args.output_wav, mode='w', samplerate=tts.sample_rate, channels=1)\n        wav.write(audio)\n        wav.close()\n        print(f\"Wrote audio to {args.output_wav}\")\n    "
  },
  {
    "path": "jetson_voice/utils/__init__.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nfrom .config import global_config, ConfigDict, ConfigArgParser\r\nfrom .resource import find_resource, load_resource, load_model, list_models\r\n\r\nfrom .audio import *\r\nfrom .softmax import softmax, normalize_logits"
  },
  {
    "path": "jetson_voice/utils/audio.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport math\nimport pprint\nimport logging\nimport librosa\nimport soundfile\n\nimport pyaudio as pa\nimport numpy as np\n\n\ndef audio_db(samples):\n    \"\"\"\n    Compute RMS of audio samples in dB.\n    \"\"\"\n    rms = librosa.feature.rms(y=samples, frame_length=samples.shape[0], center=False)[0][0]\n\n    if rms != 0.0:\n        return 20.0 * math.log10(rms)\n    else:\n        return -100.0\n        \n        \ndef audio_to_float(samples):\n    \"\"\"\n    Convert audio samples to 32-bit float in the range [-1,1]\n    \"\"\"\n    if samples.dtype == np.float32:\n        return samples\n        \n    return samples.astype(np.float32) / 32768\n  \n\ndef audio_to_int16(samples):\n    \"\"\"\n    Convert audio samples to 16-bit integers (np.int16).\n    Float32 samples are expected in the range [-1,1] and are scaled by 32768.\n    NOTE(review): a sample of exactly 1.0 scales to 32768, which overflows int16 - consider clipping.\n    \"\"\"\n    if samples.dtype == np.int16:\n        return samples\n    elif samples.dtype == np.float32:\n        return (samples * 32768).astype(np.int16)\n    else:\n        return samples.astype(np.int16)\n        \n    \ndef AudioInput(wav=None, mic=None, sample_rate=16000, chunk_size=16000):\n    \"\"\"\n    Create an audio input stream from wav file or microphone.\n    Either the wav or mic argument needs to be specified.\n    \n    Parameters:\n        wav (string) -- path to .wav file\n        mic (int) -- microphone device index\n        sample_rate (int) -- the desired sample rate in Hz\n        chunk_size (int) -- the number of samples returned per next() iteration\n        \n    Returns AudioWavStream or AudioMicStream\n    \"\"\"\n    if mic is not None and mic != '':\n        return AudioMicStream(mic, sample_rate=sample_rate, chunk_size=chunk_size)\n    elif wav is not None and wav != '':\n        return AudioWavStream(wav, 
sample_rate=sample_rate, chunk_size=chunk_size)\n    else:\n        raise ValueError('either wav or mic argument must be specified')\n \n \nclass AudioWavStream:\n    \"\"\"\n    Audio playback 
stream from .wav file\n    \"\"\"\n    def __init__(self, filename, sample_rate, chunk_size):\n        self.filename = filename\n        self.chunk_size = chunk_size\n        self.sample_rate = sample_rate\n                \n        if not os.path.isfile(filename):\n            raise IOError(f'could not find file {filename}')\n            \n        logging.info(f\"loading audio '{filename}'\")\n        \n        self.samples, _ = librosa.load(filename, sr=sample_rate, mono=True)\n        self.position = 0\n\n    def open(self):\n        pass\n        \n    def close(self):\n        pass\n        \n    def reset(self):\n        self.position = 0\n        \n    def next(self):\n        if self.position >= len(self.samples):\n            return None\n        \n        chunk = self.samples[self.position : min(self.position + self.chunk_size, len(self.samples))]\n        \n        if len(chunk) < self.chunk_size:\n            chunk = np.pad(chunk, (0, self.chunk_size-len(chunk)), mode='constant')\n            \n        self.position += self.chunk_size\n        return chunk\n        \n    def __next__(self):\n        samples = self.next()\n        \n        if samples is None:\n            raise StopIteration\n        else:\n            return samples\n        \n    def __iter__(self):\n        self.position = 0\n        return self\n\n\nclass AudioMicStream:\n    \"\"\"\n    Live audio stream from microphone input device.\n    \"\"\"\n    def __init__(self, device, sample_rate, chunk_size):\n        self.stream = None\n        self.interface = pa.PyAudio()\n        \n        self.device_info = find_audio_device(device, self.interface)\n        self.device_id = self.device_info['index']\n        self.device_sample_rate = sample_rate\n        self.device_chunk_size = chunk_size\n        \n        self.sample_rate = sample_rate\n        self.chunk_size = chunk_size\n        \n        print('Audio Input Device:')\n        pprint.pprint(self.device_info)\n    \n    def 
__del__(self):\n        self.close()\n        self.interface.terminate()\n        \n    def open(self):\n        if self.stream:\n            return\n        \n        sample_rates = [self.sample_rate, int(self.device_info['defaultSampleRate']), 16000, 22050, 32000, 44100]\n        chunk_sizes = []\n        \n        for sample_rate in sample_rates:\n            chunk_sizes.append(int(self.chunk_size * sample_rate / self.sample_rate))\n            \n        for sample_rate, chunk_size in zip(sample_rates, chunk_sizes):\n            try:    \n                logging.info(f'trying to open audio input {self.device_id} with sample_rate={sample_rate} chunk_size={chunk_size}')\n                \n                self.stream = self.interface.open(format=pa.paInt16,\n                                channels=1,\n                                rate=sample_rate,\n                                input=True,\n                                input_device_index=self.device_id,\n                                frames_per_buffer=chunk_size)\n                                \n                self.device_sample_rate = sample_rate\n                self.device_chunk_size = chunk_size\n                \n                break\n                \n            except OSError as err:\n                print(err)\n                logging.warning(f'failed to open audio input {self.device_id} with sample_rate={sample_rate}')\n                self.stream = None\n                \n        if self.stream is None:\n            logging.error(f'failed to open audio input device {self.device_id} with any of these sample rates:')\n            logging.error(str(sample_rates))\n            raise ValueError(f\"audio input device {self.device_id} couldn't be opened or does not support any of the above sample rates\")\n                      \n        print(f\"\\naudio stream opened on device {self.device_id} ({self.device_info['name']})\")\n        print(\"you can begin speaking now... 
(press Ctrl+C to exit)\\n\")\n            \n    def close(self):\n        if self.stream is not None:\n            self.stream.stop_stream()\n            self.stream.close()\n            self.stream = None\n     \n    def reset(self):\n        self.close()\n        self.open()\n        \n    def next(self):\n        self.open()\n            \n        samples = self.stream.read(self.device_chunk_size, exception_on_overflow=False)\n        samples = np.frombuffer(samples, dtype=np.int16)\n        \n        if self.sample_rate != self.device_sample_rate:\n            samples = audio_to_float(samples)\n            samples = librosa.resample(samples, self.device_sample_rate, self.sample_rate)\n            \n            if len(samples) != self.chunk_size:\n                logging.warning(f'resampled input audio has {len(samples)}, but expected {self.chunk_size} samples')\n                \n        return samples\n        \n    def __next__(self):\n        samples = self.next()\n        \n        if samples is None:\n            raise StopIteration\n        else:\n            return samples\n        \n    def __iter__(self):\n        self.open()\n        return self\n        \n\nclass AudioOutput:\n    \"\"\"\n    Audio output stream to a speaker.\n    \"\"\"\n    def __init__(self, device, sample_rate, chunk_size=4096):\n        self.stream = None\n        \n        if device is None:\n            self.device_id = None\n            logging.warning(f\"creating pass-through audio output without a device\")\n            return\n            \n        self.interface = pa.PyAudio()\n        self.device_info = find_audio_device(device, self.interface)\n        self.device_id = self.device_info['index']\n        self.chunk_size = chunk_size\n        self.sample_rate = sample_rate\n        self.requested_rate = sample_rate\n        \n        print('Audio Output Device:')\n        pprint.pprint(self.device_info)\n        \n        self.open()\n    \n    def __del__(self):\n        
if self.device_id is None:\n            return\n            \n        self.close()\n        self.interface.terminate()\n        \n    def open(self):\n        if self.stream or self.device_id is None:\n            return\n            \n        try:\n            self.stream = self.interface.open(format=pa.paFloat32,\n                            channels=1, rate=self.sample_rate,\n                            frames_per_buffer=self.chunk_size,\n                            output=True, output_device_index=self.device_id)\n        except:\n            self.sample_rate = int(self.device_info['defaultSampleRate'])\n            logging.error(f\"failed to open audio output device with sample_rate={self.requested_rate}, trying again with sample_rate={self.sample_rate}\")\n            \n            self.stream = self.interface.open(format=pa.paFloat32,\n                            channels=1, rate=self.sample_rate,\n                            frames_per_buffer=self.chunk_size,\n                            output=True, output_device_index=self.device_id)\n        \n        logging.info(f\"opened audio output device {self.device_id} ({self.device_info['name']})\")\n        \n    def close(self):\n        if self.stream is not None:\n            self.stream.stop_stream()\n            self.stream.close()\n            self.stream = None\n       \n    def write(self, samples):\n        if self.device_id is None:\n            return\n            \n        self.open()\n        samples = audio_to_float(samples)\n        \n        if self.requested_rate != self.sample_rate:\n            samples = librosa.resample(samples, self.requested_rate, self.sample_rate)\n            #wav = soundfile.SoundFile('data/audio/resample_test.wav', mode='w', samplerate=self.sample_rate, channels=1)\n            #wav.write(samples)\n            #wav.close()\n            \n        self.stream.write(samples.tobytes())\n        \n        \n#\n# device enumeration\n# \n_audio_device_info = None\n\ndef 
_get_audio_devices(audio_interface=None):\n    global _audio_device_info\n    \n    if _audio_device_info:\n        return _audio_device_info\n        \n    if audio_interface:\n        interface = audio_interface\n    else:\n        interface = pa.PyAudio()\n        \n    info = interface.get_host_api_info_by_index(0)\n    numDevices = info.get('deviceCount')\n    \n    _audio_device_info = []\n    \n    for i in range(0, numDevices):\n        _audio_device_info.append(interface.get_device_info_by_host_api_device_index(0, i))\n    \n    if not audio_interface:\n        interface.terminate()\n        \n    return _audio_device_info\n     \n     \ndef find_audio_device(device, audio_interface=None):\n    \"\"\"\n    Find an audio device by it's name or ID number.\n    \"\"\"\n    devices = _get_audio_devices(audio_interface)\n    \n    try:\n        device_id = int(device)\n    except ValueError:\n        if not isinstance(device, str):\n            raise ValueError(\"expected either a string or an int for 'device' parameter\")\n            \n        found = False\n        \n        for id, dev in enumerate(devices):\n            if device.lower() == dev['name'].lower():\n                device_id = id\n                found = True\n                break\n                \n        if not found:\n            raise ValueError(f\"could not find audio device with name '{device}'\")\n            \n    if device_id < 0 or device_id >= len(devices):\n        raise ValueError(f\"invalid audio device ID ({device_id})\")\n        \n    return devices[device_id]\n                \n   \ndef list_audio_inputs():\n    \"\"\"\n    Print out information about present audio input devices.\n    \"\"\"\n    devices = _get_audio_devices()\n\n    print('')\n    print('----------------------------------------------------')\n    print(f\" Audio Input Devices\")\n    print('----------------------------------------------------')\n        \n    for i, dev_info in enumerate(devices):    \n    
    if (dev_info.get('maxInputChannels')) > 0:\n            print(\"Input Device ID {:d} - '{:s}' (inputs={:.0f}) (sample_rate={:.0f})\".format(i,\n                  dev_info.get('name'), dev_info.get('maxInputChannels'), dev_info.get('defaultSampleRate')))\n                 \n    print('')\n    \n    \ndef list_audio_outputs():\n    \"\"\"\n    Print out information about present audio output devices.\n    \"\"\"\n    devices = _get_audio_devices()\n    \n    print('')\n    print('----------------------------------------------------')\n    print(f\" Audio Output Devices\")\n    print('----------------------------------------------------')\n        \n    for i, dev_info in enumerate(devices):  \n        if (dev_info.get('maxOutputChannels')) > 0:\n            print(\"Output Device ID {:d} - '{:s}' (outputs={:.0f}) (sample_rate={:.0f})\".format(i,\n                  dev_info.get('name'), dev_info.get('maxOutputChannels'), dev_info.get('defaultSampleRate')))\n                  \n    print('')\n    \n    \ndef list_audio_devices():\n    \"\"\"\n    Print out information about present audio input and output devices.\n    \"\"\"\n    list_audio_inputs()\n    list_audio_outputs()\n\n              \n\n              "
  },
  {
    "path": "jetson_voice/utils/config.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport json\r\nimport pprint\r\nimport logging\r\nimport argparse\r\n\r\n\r\n#\r\n# Default global configuration\r\n#\r\n# This can be overridden at runtime with command-line options (see ConfigArgParser)\r\n# such as --global-config to load your own configuration from json file,\r\n# or by calling config.load('my_config.json')\r\n#\r\n# You can also set the options directly on the 'config' object, e.g.\r\n#\r\n#    config.model_dir = '/path/to/my/models'\r\n#    config.log_level = 'warning'\r\n#\r\n# It's recommended to use one of the methods above instead of changing _default_global_config directly.\r\n#\r\n_default_global_config = {\r\n    'version' : 0.1,\r\n    'model_dir' : '/jetson-voice/data/networks',\r\n    'model_manifest' : '/jetson-voice/data/networks/manifest.json',\r\n    'default_backend' : 'tensorrt',\r\n    'log_level' : 'info',\r\n    'debug' : False,\r\n    'profile' : False\r\n}\r\n\r\n\r\nclass ConfigDict(dict):\r\n    \"\"\"\r\n    Configuration dict that can be loaded from JSON and has members\r\n    accessible via attributes and can watch for updates to keys.\r\n    \"\"\"\r\n    def __init__(self, *args, path=None, watch=None, **kwargs):\r\n        \"\"\"\r\n        Parameters:\r\n          path (str) -- Path to JSON file to load from\r\n          \r\n          watch (function or dict) -- A callback function that gets called when a key is set.\r\n                                      Should have a function signature like my_watch(key, value).\r\n                                      This can also be a dict of key names and functions,\r\n                                      and each function will only be called when its particular\r\n                                      key has been set.  You can also subclass ConfigDict and\r\n                                      override the __watch__() member function.\r\n        \"\"\"\r\n        super(ConfigDict, self).__init__()\r\n        \r\n        # path/watch live in __dict__ so they are attributes rather than config keys\r\n        self.__dict__['path'] = path\r\n        self.__dict__['watch'] = watch\r\n        \r\n        # route the initial contents through __setitem__ (via update) so that nested\r\n        # dicts become ConfigDicts and the watch callbacks fire for every key\r\n        self.update(*args, **kwargs)\r\n        \r\n        if path:\r\n            self.load(path)\r\n            \r\n    def load(self, path, clear=False):\r\n        \"\"\"\r\n        Load from JSON file.\r\n        \"\"\"\r\n        from .resource import find_resource  # import here to avoid circular dependency\r\n        \r\n        path = find_resource(path)\r\n        self.__dict__['path'] = path\r\n        \r\n        if clear:\r\n            self.clear()\r\n            \r\n        with open(path) as file:\r\n            config_dict = json.load(file)\r\n        \r\n        self.update(config_dict)\r\n        \r\n    def __getattr__(self, attr):\r\n        if attr in self.__dict__:\r\n            return self.__dict__[attr]\r\n        else:\r\n            return self[attr]\r\n        \r\n    def __setattr__(self, attr, value):\r\n        if attr in self.__dict__:\r\n            self.__dict__[attr] = value\r\n        else:\r\n            self[attr] = value\r\n        \r\n    def __setitem__(self, key, value):\r\n        if isinstance(value, dict):\r\n            value = ConfigDict(value, watch=self.watch)\r\n            value.__dict__['path'] = self.path\r\n            \r\n        super(ConfigDict, self).__setitem__(key, value)\r\n        self.__watch__(key, value)\r\n    \r\n    def __watch__(self, key, value):\r\n        #print(f'watch {key} -> {value}')\r\n\r\n        if not self.watch:\r\n            return\r\n            \r\n        if isinstance(self.watch, dict):\r\n            if key in self.watch:\r\n                self.watch[key](key, value)\r\n        else:\r\n            self.watch(key, value)\r\n            \r\n    def __str__(self):\r\n        return pprint.pformat(self)\r\n        \r\n    #def __repr__(self):\r\n    #    return pprint.saferepr(self)\r\n        \r\n    def setdefault(self, key, default=None):\r\n        \"\"\"\r\n        Like dict.setdefault(), but a dict default is converted to a ConfigDict\r\n        and the watch callback fires if the key was newly inserted.\r\n        Returns the value now stored under key.\r\n        \"\"\"\r\n        if isinstance(default, dict):\r\n            default = ConfigDict(default, watch=self.watch)\r\n            default.__dict__['path'] = self.path\r\n            \r\n        changed = key not in self\r\n        value = super(ConfigDict, self).setdefault(key, default)\r\n        \r\n        if changed:\r\n            self.__watch__(key, value)\r\n            \r\n        return value\r\n        \r\n    def update(self, *args, **kwargs):\r\n        for k, v in dict(*args, **kwargs).items():\r\n            self[k] = v\r\n        \r\n\r\n#\r\n# logging handlers\r\n#\r\nlogging.basicConfig(format='[%(asctime)s] %(filename)s:%(lineno)d - %(message)s', datefmt=\"%Y-%m-%d %H:%M:%S\", level=logging.INFO) \r\n\r\nglobal_config = None\r\n\r\ndef _set_log_level(key, value):\r\n    log_value = value.upper()\r\n    \r\n    if log_value == 'VERBOSE':\r\n        log_value = 'DEBUG'\r\n        \r\n    log_level = getattr(logging, log_value, None)\r\n    \r\n    if not isinstance(log_level, int):\r\n        raise ValueError(f'Invalid log level: {value}')\r\n       \r\n    logging.getLogger().setLevel(log_level)\r\n    logging.debug(f'set logging level to {value}')\r\n\r\n    if global_config is not None and value.upper() == 'DEBUG':\r\n        global_config['debug'] = True\r\n    \r\n#\r\n# global config definition\r\n#\r\nglobal_config = ConfigDict(_default_global_config, watch={'log_level':_set_log_level})\r\n\r\nif global_config.log_level.upper() == 'DEBUG':\r\n    global_config['debug'] = True\r\n    \r\nlogging.debug(f'global config:\\n{global_config}')\r\n\r\n\r\n#\r\n# custom arg parser\r\n#\r\nclass ConfigArgParser(argparse.ArgumentParser):\r\n    \"\"\"\r\n    ArgumentParser that provides global configuration options.\r\n    \"\"\"\r\n    def __init__(self, *args, **kwargs):\r\n        super(ConfigArgParser, self).__init__(*args, **kwargs)\r\n    \r\n        self.add_argument('--global-config', default=None, type=str, help='path to JSON file to load global configuration from')\r\n        self.add_argument('--model-dir', default=_default_global_config['model_dir'], help=f\"sets the root path of the models (default '{_default_global_config['model_dir']}')\")\r\n        self.add_argument('--model-manifest', default=_default_global_config['model_manifest'], help=f\"sets the path to the model manifest file (default '{_default_global_config['model_manifest']}')\")\r\n        self.add_argument('--list-models', action='store_true', help='lists the available models (from $model_dir/manifest.json)')\r\n        self.add_argument('--default-backend', default=_default_global_config['default_backend'], help=f\"sets the default backend to use for model execution (default '{_default_global_config['default_backend']}')\")\r\n        self.add_argument('--profile', action='store_true', help='enables model performance profiling')\r\n        self.add_argument('--verbose', action='store_true', help='sets the logging level to verbose')\r\n        self.add_argument('--debug', action='store_true', help='sets the logging level to debug')\r\n        \r\n        log_levels = ['debug', 'verbose', 'info', 'warning', 'error', 'critical']\r\n        \r\n        self.add_argument('--log-level', default=_default_global_config['log_level'], type=str, choices=log_levels,\r\n                          help=f\"sets the logging level to one of the options above (default={_default_global_config['log_level']})\")\r\n        \r\n    def parse_args(self, *args, **kwargs):\r\n        args = super(ConfigArgParser, self).parse_args(*args, **kwargs)\r\n        \r\n        global_config.log_level = args.log_level\r\n        global_config.model_dir = args.model_dir\r\n        \r\n        global_config.model_manifest = args.model_manifest\r\n        global_config.default_backend = args.default_backend\r\n        \r\n        if args.profile:\r\n            global_config.profile = True\r\n            \r\n        if args.verbose:\r\n            global_config.log_level = 'verbose'\r\n            \r\n        if args.debug:\r\n            global_config.log_level = 'debug'\r\n        \r\n        if args.global_config:\r\n            global_config.load(args.global_config)\r\n            \r\n        if args.list_models:\r\n            from .resource import list_models\r\n            list_models()\r\n            \r\n        logging.debug(f'global config:\\n{global_config}')    \r\n        return args\r\n\r\n"
  },
  {
    "path": "jetson_voice/utils/resource.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport json\r\nimport time\r\nimport tqdm\r\nimport pprint\r\nimport logging\r\nimport tarfile\r\nimport urllib.request\r\nimport importlib\r\n\r\nfrom .config import global_config, ConfigDict\r\n\r\n\r\ndef find_resource(path):\r\n    \"\"\"\r\n    Find a resource by checking some common paths.\r\n    \"\"\"\r\n    if os.path.exists(path):\r\n        return path\r\n        \r\n    search_dirs = [global_config.model_dir,\r\n                   os.path.join(global_config.model_dir, 'asr'),\r\n                   os.path.join(global_config.model_dir, 'nlp'),\r\n                   os.path.join(global_config.model_dir, 'tts')]\r\n    \r\n    for search_dir in search_dirs:\r\n        search_path = os.path.join(search_dir, path)\r\n        \r\n        if os.path.exists(search_path):\r\n            return search_path\r\n    \r\n    raise IOError(f\"failed to locate resource '{path}'\")\r\n\r\n\r\ndef load_resource(resource, factory_map, *args, **kwargs):\r\n    \"\"\"\r\n    Load an instance of a resource from a config or service name.\r\n    The factory_map dict maps the backend names to class names.\r\n    Returns the resource instance, or the config if factory_map is null.\r\n    \"\"\"\r\n    if isinstance(resource, str):\r\n        root, ext = os.path.splitext(resource)\r\n        \r\n        if len(ext) > 0:\r\n            ext = ext.lower()\r\n            \r\n            if ext == '.json':\r\n                config = ConfigDict(path=resource)\r\n            elif ext == '.onnx' or ext == '.engine' or ext == '.plan':\r\n                config = ConfigDict(path=root + '.json')\r\n            else:\r\n                raise ValueError(f\"resource '{resource}' has invalid extension '{ext}'\")\r\n        else:\r\n            manifest = download_model(resource)\r\n\r\n            if manifest['type'] == 'model':\r\n                config = ConfigDict(path=get_model_config_path(manifest=manifest))\r\n            else:\r\n                config = ConfigDict(backend=manifest['backend'], type=manifest['name'])\r\n    \r\n    elif isinstance(resource, ConfigDict):\r\n        config = resource\r\n    elif isinstance(resource, dict):\r\n        config = ConfigDict(resource)\r\n    else:\r\n        raise ValueError(f\"expected string or dict type, instead got {type(resource).__name__}\")\r\n    \r\n    config.setdefault('backend', global_config.default_backend)\r\n    \r\n    if factory_map is None:\r\n        return config\r\n        \r\n    if config.backend not in factory_map:\r\n        raise ValueError(f\"'{config.path}' has invalid backend '{config.backend}' (valid options are: {', '.join(factory_map.keys())})\")\r\n        \r\n    class_name = factory_map[config.backend].rsplit(\".\", 1)\r\n    class_type = getattr(importlib.import_module(class_name[0]), class_name[1])\r\n    \r\n    logging.debug(f\"creating instance of {factory_map[config.backend]} for '{config.path}' (backend {config.backend})\")\r\n    logging.debug(class_type)\r\n    \r\n    return class_type(config, *args, **kwargs)\r\n    \r\n    \r\ndef load_model(config, dynamic_shapes=None):\r\n    \"\"\"\r\n    Loads an ONNX model through a backend (either TensorRT or onnxruntime)\r\n    \"\"\"\r\n    factory_map = {\r\n        'tensorrt' : 'jetson_voice.backends.tensorrt.TRTModel',\r\n        'onnxruntime' : 'jetson_voice.backends.onnxruntime.OnnxRuntimeModel'\r\n    }\r\n    \r\n    config.setdefault('backend', global_config.default_backend)\r\n    config.setdefault('model_path', os.path.splitext(config.path)[0] + '.onnx')\r\n    \r\n    if not os.path.exists(config.model_path):\r\n        model_path = os.path.join(os.path.dirname(config.path), config.model_path)\r\n        \r\n        if not os.path.exists(model_path):\r\n            raise IOError(f\"couldn't find file '{config.model_path}'\")\r\n        else:\r\n            config.model_path = model_path\r\n\r\n    if config.backend not in factory_map:\r\n        raise ValueError(f\"'{config.path}' has invalid backend '{config.backend}' (valid options are: {', '.join(factory_map.keys())})\")\r\n        \r\n    class_name = factory_map[config.backend].rsplit(\".\", 1)\r\n    class_type = getattr(importlib.import_module(class_name[0]), class_name[1])\r\n    \r\n    logging.info(f\"loading model '{config.model_path}' with {factory_map[config.backend]}\")\r\n    logging.debug(class_type)\r\n    \r\n    return class_type(config, dynamic_shapes=dynamic_shapes)\r\n    \r\n    \r\ndef load_models_manifest(path=None):\r\n    \"\"\"\r\n    Load the models manifest file.\r\n    If the path isn't overridden, it will use global_config.model_manifest\r\n    \"\"\"\r\n    if path is None:\r\n        path = global_config.model_manifest\r\n        \r\n    with open(path) as file:\r\n        manifest = json.load(file)\r\n        \r\n    for key in manifest:\r\n        manifest[key].setdefault('name', key)\r\n        manifest[key].setdefault('config', key + '.json')\r\n        manifest[key].setdefault('type', 'model')\r\n        \r\n    return manifest\r\n    \r\n  \r\ndef find_model_manifest(name):\r\n    \"\"\"\r\n    Find a model manifest entry by name / alias.\r\n    \"\"\"\r\n    manifest = load_models_manifest()\r\n    \r\n    for key in manifest:\r\n        if key.lower() == name.lower():\r\n            return manifest[key]\r\n        \r\n        if 'alias' in manifest[key]:\r\n            if isinstance(manifest[key]['alias'], str):\r\n                aliases = [manifest[key]['alias']]\r\n            else:\r\n                aliases = manifest[key]['alias']\r\n                \r\n            for alias in aliases:\r\n                if alias.lower() == name.lower():\r\n                    return manifest[key]\r\n      \r\n    raise ValueError(f\"could not find '{name}' in manifest '{global_config.model_manifest}'\")\r\n    \r\n \r\ndef download_model(name, max_attempts=10, retry_time=5):\r\n    \"\"\"\r\n    Download a model if it hasn't already been downloaded.\r\n    \"\"\"\r\n    manifest = find_model_manifest(name)\r\n    \r\n    if manifest is None:\r\n        return None\r\n      \r\n    if manifest['type'] != 'model':\r\n        return manifest\r\n        \r\n    if os.path.exists(get_model_config_path(manifest=manifest)):\r\n        return manifest\r\n\r\n    class DownloadProgressBar(tqdm.tqdm):\r\n        def update_to(self, b=1, bsize=1, tsize=None):\r\n            if tsize is not None:\r\n                self.total = tsize\r\n            self.update(b * bsize - self.n)\r\n        \r\n    def attempt_download(attempt):\r\n        logging.info(f\"downloading '{manifest['name']}' from {manifest['url']} (attempt {attempt} of {max_attempts})\")\r\n\r\n        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=manifest['name']) as t:\r\n            try:\r\n                filename, _ = urllib.request.urlretrieve(manifest['url'], reporthook=t.update_to)\r\n            except Exception as error:\r\n                t.close()\r\n                logging.error(error)\r\n                return None\r\n                \r\n            return filename\r\n        \r\n    for attempt in range(1, max_attempts+1):\r\n        filename = attempt_download(attempt)\r\n        \r\n        if filename is not None:\r\n            break\r\n            \r\n        logging.error(f\"failed to download '{manifest['name']}' from {manifest['url']} (attempt {attempt} of {max_attempts})\")\r\n        \r\n        if attempt == max_attempts:\r\n            raise ValueError(f\"failed to download '{manifest['name']}' from {manifest['url']} (max attempts exceeded)\")\r\n            \r\n        logging.info(f\"waiting {retry_time} seconds before trying again...\")\r\n        time.sleep(retry_time)\r\n        \r\n    extract_dir = os.path.join(global_config.model_dir, manifest['domain'])\r\n    logging.info(f\"extracting {filename} to {os.path.join(extract_dir, manifest['name'])}\")\r\n    \r\n    try:\r\n        with tarfile.open(filename, \"r:gz\") as tar:\r\n            tar.list()\r\n            \r\n            # the archive comes from the network, so refuse any member that would be\r\n            # written outside of extract_dir (absolute paths or '..' components)\r\n            extract_root = os.path.realpath(extract_dir)\r\n            \r\n            for member in tar.getmembers():\r\n                member_path = os.path.realpath(os.path.join(extract_root, member.name))\r\n                \r\n                if os.path.commonpath([extract_root, member_path]) != extract_root:\r\n                    raise ValueError(f\"refusing to extract '{member.name}' from {filename} (outside of {extract_dir})\")\r\n                    \r\n            tar.extractall(path=extract_dir)\r\n    finally:\r\n        # always remove the downloaded temp archive, even if extraction failed\r\n        os.remove(filename)\r\n        \r\n    return manifest\r\n        \r\n    \r\ndef get_model_config_path(name=None, manifest=None):\r\n    \"\"\"\r\n    Gets the path to the model config from its name or manifest entry.\r\n    \"\"\"\r\n    if name is None and manifest is None:\r\n        raise ValueError('must specify either name or manifest arguments')\r\n        \r\n    if manifest is None:\r\n        manifest = find_model_manifest(name)\r\n        \r\n    if manifest['type'] != 'model':\r\n        raise ValueError(f\"resource '{manifest['name']}' is not a model (type='{manifest['type']}')\")\r\n    \r\n    if len(os.path.dirname(manifest['config'])) > 0:  # config already includes a sub-directory\r\n        return os.path.join(global_config.model_dir, manifest['domain'], manifest['config'])\r\n    else:  \r\n        return os.path.join(global_config.model_dir, manifest['domain'], manifest['name'], manifest['config'])\r\n    \r\n   \r\ndef list_models():\r\n    \"\"\"\r\n    Print out the models available.\r\n    \"\"\"\r\n    manifest = load_models_manifest()\r\n    \r\n    print('')\r\n    print('----------------------------------------------------')\r\n    print(f\" Models\")\r\n    print('----------------------------------------------------')\r\n\r\n    for key in list(manifest):\r\n        if manifest[key]['type'] != 'model':\r\n            manifest.pop(key)\r\n            \r\n    pprint.pprint(manifest)\r\n\r\n    print('')"
  },
  {
    "path": "jetson_voice/utils/softmax.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport numpy as np\n\n\ndef softmax(x, theta=1.0, axis=None):\n    \"\"\"\n    Compute the softmax of each element along an axis of x.\n\n    Parameters\n    ----------\n      x: array_like (ND-array or nested lists). Probably should be floats.\n    \n      theta (optional): float parameter, used as a multiplier\n                        prior to exponentiation. Default = 1.0\n        \n      axis (optional): axis to compute values along. Default is the\n                       first non-singleton axis (or the last axis if\n                       every axis has length 1).\n\n    Returns an array the same size as x. The result will sum to 1\n    along the specified axis.\n    \"\"\"\n    y = np.atleast_2d(x)\n\n    # find axis (fall back to the last axis when there is no non-singleton one,\n    # instead of letting StopIteration escape from next())\n    if axis is None:\n        axis = next((i for i, dim in enumerate(y.shape) if dim > 1), y.ndim - 1)\n\n    # multiply y against the theta parameter,\n    y = y * float(theta)\n\n    # subtract the max for numerical stability\n    y = y - np.expand_dims(np.max(y, axis = axis), axis)\n\n    # exponentiate y\n    y = np.exp(y)\n\n    # take the sum along the specified axis\n    ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)\n\n    # finally: divide elementwise\n    p = y / ax_sum\n\n    # flatten if x was 1D (np.ndim also works for plain lists, unlike x.shape)\n    if np.ndim(x) == 1:\n        p = p.flatten()\n\n    return p\n\n\ndef normalize_logits(logits):\n    \"\"\"\n    Normalize logits such that they are distributed between [0,1]\n    (i.e. a softmax over the last axis).\n    \"\"\"\n    logits = np.asarray(logits)\n\n    if logits.size == 0:\n        return np.exp(logits)\n\n    # shift by the max along the last axis first so np.exp() can't overflow to inf\n    # for large logits; the shift cancels out mathematically\n    shifted = logits - np.max(logits, axis=-1, keepdims=True)\n    return np.exp(shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True)))\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/exportable.original.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom abc import ABC\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Dict\n\nimport onnx\nimport torch\n\nfrom nemo.core.classes import typecheck\nfrom nemo.core.neural_types import AxisKind, NeuralType\nfrom nemo.utils import logging\nfrom nemo.utils.export_utils import replace_for_export\n\ntry:\n    import onnx_graphsurgeon as gs\n\n    ONNX_GRAPHSURGEON_AVAILABLE = True\n\nexcept (ImportError, ModuleNotFoundError):\n    ONNX_GRAPHSURGEON_AVAILABLE = False\n\n__all__ = ['ExportFormat', 'Exportable']\n\n\nclass ExportFormat(Enum):\n    \"\"\"Which format to use when exporting a Neural Module for deployment\"\"\"\n\n    ONNX = (1,)\n    TORCHSCRIPT = (2,)\n\n\n_EXT_DICT = {\n    \".pt\": ExportFormat.TORCHSCRIPT,\n    \".onnx\": ExportFormat.ONNX,\n}\n\n\nclass Exportable(ABC):\n    \"\"\"\n    This Interface should be implemented by particular classes derived from nemo.core.NeuralModule or nemo.core.ModelPT.\n    It gives these entities ability to be exported for deployment to formats such as ONNX.\n    \"\"\"\n\n    @staticmethod\n    def get_format(filename: str):\n        _, ext = os.path.splitext(filename)\n        try:\n            return _EXT_DICT[ext]\n        except KeyError:\n            raise ValueError(f\"Export file {filename} extension does not correspond to any export 
format!\")\n\n    @property\n    def input_module(self):\n        return self\n\n    @property\n    def output_module(self):\n        return self\n\n    def get_input_names(self, input_example):\n        if isinstance(input_example, Dict):\n            input_names = list(input_example.keys())\n        else:\n            if not (hasattr(self, 'input_types')):\n                raise NotImplementedError(\n                    'For export to work you must define input_types or pass names in input_example'\n                )\n            input_names = list(self.input_types.keys())\n        # remove unnecessary inputs for input_ports\n        for name in self.disabled_deployment_input_names:\n            input_names.remove(name)\n        return input_names\n\n    def get_output_names(self, output_example):\n        if isinstance(output_example, Dict):\n            output_names = list(output_example.keys())\n        else:\n            if not (hasattr(self, 'output_types')):\n                raise NotImplementedError(\n                    'For export to work you must define output_types or pass names in output_example'\n                )\n            output_names = list(self.output_types.keys())\n            # remove unnecessary inputs for input_ports\n        for name in self.disabled_deployment_output_names:\n            output_names.remove(name)\n        return output_names\n\n    def get_input_dynamic_axes(self, input_names):\n        dynamic_axes = defaultdict(list)\n        for name in input_names:\n            dynamic_axes = {\n                **dynamic_axes,\n                **self._extract_dynamic_axes(name, self.input_types[name]),\n            }\n        return dynamic_axes\n\n    def get_output_dynamic_axes(self, output_names):\n        dynamic_axes = defaultdict(list)\n        for name in output_names:\n            dynamic_axes = {\n                **dynamic_axes,\n                **self._extract_dynamic_axes(name, self.output_types[name]),\n            }\n     
   return dynamic_axes\n\n    def export(\n        self,\n        output: str,\n        input_example=None,\n        output_example=None,\n        verbose=False,\n        export_params=True,\n        do_constant_folding=True,\n        keep_initializers_as_inputs=False,\n        onnx_opset_version: int = 12,\n        try_script: bool = False,\n        set_eval: bool = True,\n        check_trace: bool = True,\n        use_dynamic_axes: bool = True,\n        dynamic_axes=None,\n        check_tolerance=0.01,\n        forward_method=None,\n    ):\n\n        qual_name = self.__module__ + '.' + self.__class__.__qualname__\n        output_descr = qual_name + ' exported to ONNX'\n        exported = ([output], [output_descr])\n\n        try:\n            # Disable typechecks\n            typecheck.set_typecheck_enabled(enabled=False)\n\n            # Allow user to completely override forward method to export\n            if forward_method is None and hasattr(type(self), \"forward_for_export\"):\n                forward_method = type(self).forward_for_export\n\n            if forward_method:\n                old_forward_method = type(self).forward\n                type(self).forward = forward_method\n\n            # Set module to eval mode\n            if set_eval:\n                self.eval()\n\n            format = self.get_format(output)\n            self._prepare_for_export()\n\n            with torch.jit.optimized_execution(True):\n                jitted_model = None\n                if try_script:\n                    try:\n                        jitted_model = torch.jit.script(self)\n                    except Exception as e:\n                        print(\"jit.script() failed!\", e)\n\n            if input_example is None:\n                input_example = self.input_module.input_example()\n\n            with torch.jit.optimized_execution(True):\n                if format == ExportFormat.TORCHSCRIPT:\n                    if isinstance(input_example, Dict):\n          
              input_example = tuple(input_example.values())\n\n                    if jitted_model is None:\n                        jitted_model = torch.jit.trace(\n                            self,\n                            input_example,\n                            strict=False,\n                            optimize=True,\n                            check_trace=check_trace,\n                            check_tolerance=check_tolerance,\n                        )\n                    jitted_model.save(output)\n                    assert os.path.exists(output)\n\n                elif format == ExportFormat.ONNX:\n                    if jitted_model is None:\n                        jitted_model = self\n                    if output_example is None:\n                        if isinstance(input_example, tuple):\n                            output_example = self.forward(*input_example)\n                        else:\n                            output_example = self.forward(input_example)\n\n                    input_names = self.input_module.get_input_names(input_example)\n                    output_names = self.output_module.get_output_names(output_example)\n\n                    # dynamic axis is a mapping from input/output_name => list of \"dynamic\" indices\n                    if dynamic_axes is None and use_dynamic_axes:\n                        dynamic_axes = self.input_module.get_input_dynamic_axes(input_names)\n                        dynamic_axes = {**dynamic_axes, **self.output_module.get_output_dynamic_axes(output_names)}\n\n                    if isinstance(input_example, Dict):\n                        input_example = tuple(input_example.values())\n\n                    torch.onnx.export(\n                        jitted_model,\n                        input_example,\n                        output,\n                        input_names=input_names,\n                        output_names=output_names,\n                        verbose=verbose,\n        
                export_params=export_params,\n                        do_constant_folding=do_constant_folding,\n                        keep_initializers_as_inputs=keep_initializers_as_inputs,\n                        dynamic_axes=dynamic_axes,\n                        opset_version=onnx_opset_version,\n                        example_outputs=output_example,\n                    )\n\n                    # Verify the model can be read, and is valid\n                    onnx_model = onnx.load(output)\n                    onnx.checker.check_model(onnx_model, full_check=True)\n\n                    if do_constant_folding:\n                        if not ONNX_GRAPHSURGEON_AVAILABLE:\n                            logging.info(\n                                f\"onnx-graphsurgeon module is not instlled.\"\n                                \"That may result in suboptimal optimization of exported ONNX graph (including unneeded DOUBLE initializers).\"\n                                \"Please follow the instructions available at:\"\n                                \"https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon\"\n                                \"to install onnx-graphsurgeon from source to improve exported graph.\"\n                            )\n                        else:\n                            # This pass is to remove/recast certain constants that are generated as 'double'\n                            # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)\n                            # Can probably be removed once TRT has automatic downcast for double.\n                            # However, it may still be useful even then as it seems to always make the graph shorter.\n                            graph = gs.import_onnx(onnx_model)\n                            onnx_model = gs.export_onnx(graph.fold_constants().cleanup())\n                            onnx.checker.check_model(onnx_model, full_check=True)\n   
                         onnx.save(onnx_model, output)\n                else:\n                    raise ValueError(f'Encountered unknown export format {format}.')\n        finally:\n            typecheck.set_typecheck_enabled(enabled=True)\n            if forward_method:\n                type(self).forward = old_forward_method\n        return exported\n\n    @property\n    def disabled_deployment_input_names(self):\n        \"\"\"Implement this method to return a set of input names disabled for export\"\"\"\n        return set()\n\n    @property\n    def disabled_deployment_output_names(self):\n        \"\"\"Implement this method to return a set of output names disabled for export\"\"\"\n        return set()\n\n    @property\n    def supported_export_formats(self):\n        \"\"\"Implement this method to return a set of export formats supported. Default is all types.\"\"\"\n        return set([ExportFormat.ONNX, ExportFormat.TORCHSCRIPT])\n\n    @staticmethod\n    def _extract_dynamic_axes(name: str, ntype: NeuralType):\n        \"\"\"\n        Implement this method to provide dynamic axes id for ONNX export.\n        By default, this method will extract BATCH and TIME dimension ids from each provided input/output name argument.\n\n        For example, if module/model accepts argument named \"input_signal\" with type corresponding to [Batch, Time, Dim]\n        shape, then the returned result should contain \"input_signal\" -> [0, 1] because Batch and Time are dynamic axes\n        as they can change from call to call during inference.\n\n        Args:\n            name: Name of input or output parameter\n            ntype: Corresponding Neural Type\n\n        Returns:\n\n        \"\"\"\n        dynamic_axes = defaultdict(list)\n        if ntype.axes:\n            for ind, axis in enumerate(ntype.axes):\n                if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]:\n                    dynamic_axes[name].append(ind)\n        
return dynamic_axes\n\n    def _prepare_for_export(self, replace_1D_2D=False):\n        \"\"\"\n        Override this method to prepare module for export. This is in-place operation.\n        Base version does common necessary module replacements (Apex etc)\n        \"\"\"\n        replace_for_export(self, replace_1D_2D)\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/exportable.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\nfrom abc import ABC\nfrom collections import defaultdict\nfrom enum import Enum\nfrom typing import Dict\n\nimport onnx\nimport torch\n\nfrom nemo.core.classes import typecheck\nfrom nemo.core.neural_types import AxisKind, NeuralType\nfrom nemo.utils import logging\nfrom nemo.utils.export_utils import replace_for_export\n\ntry:\n    import onnx_graphsurgeon as gs\n\n    ONNX_GRAPHSURGEON_AVAILABLE = True\n\nexcept (ImportError, ModuleNotFoundError):\n    ONNX_GRAPHSURGEON_AVAILABLE = False\n\n__all__ = ['ExportFormat', 'Exportable']\n\n\nclass ExportFormat(Enum):\n    \"\"\"Which format to use when exporting a Neural Module for deployment\"\"\"\n\n    ONNX = (1,)\n    TORCHSCRIPT = (2,)\n\n\n_EXT_DICT = {\n    \".pt\": ExportFormat.TORCHSCRIPT,\n    \".onnx\": ExportFormat.ONNX,\n}\n\n\nclass Exportable(ABC):\n    \"\"\"\n    This Interface should be implemented by particular classes derived from nemo.core.NeuralModule or nemo.core.ModelPT.\n    It gives these entities ability to be exported for deployment to formats such as ONNX.\n    \"\"\"\n\n    @staticmethod\n    def get_format(filename: str):\n        _, ext = os.path.splitext(filename)\n        try:\n            return _EXT_DICT[ext]\n        except KeyError:\n            raise ValueError(f\"Export file {filename} extension does not correspond to any export 
format!\")\n\n    @property\n    def input_module(self):\n        return self\n\n    @property\n    def output_module(self):\n        return self\n\n    def get_input_names(self, input_example):\n        if isinstance(input_example, Dict):\n            input_names = list(input_example.keys())\n        else:\n            if not (hasattr(self, 'input_types')):\n                raise NotImplementedError(\n                    'For export to work you must define input_types or pass names in input_example'\n                )\n            input_names = list(self.input_types.keys())\n        # remove unnecessary inputs for input_ports\n        for name in self.disabled_deployment_input_names:\n            input_names.remove(name)\n        return input_names\n\n    def get_output_names(self, output_example):\n        if isinstance(output_example, Dict):\n            output_names = list(output_example.keys())\n        else:\n            if not (hasattr(self, 'output_types')):\n                raise NotImplementedError(\n                    'For export to work you must define output_types or pass names in output_example'\n                )\n            output_names = list(self.output_types.keys())\n            # remove unnecessary inputs for input_ports\n        for name in self.disabled_deployment_output_names:\n            output_names.remove(name)\n        return output_names\n\n    def get_input_dynamic_axes(self, input_names):\n        dynamic_axes = defaultdict(list)\n        for name in input_names:\n            dynamic_axes = {\n                **dynamic_axes,\n                **self._extract_dynamic_axes(name, self.input_types[name]),\n            }\n        return dynamic_axes\n\n    def get_output_dynamic_axes(self, output_names):\n        dynamic_axes = defaultdict(list)\n        for name in output_names:\n            dynamic_axes = {\n                **dynamic_axes,\n                **self._extract_dynamic_axes(name, self.output_types[name]),\n            }\n     
   return dynamic_axes\n\n    def export(\n        self,\n        output: str,\n        input_example=None,\n        output_example=None,\n        verbose=False,\n        export_params=True,\n        do_constant_folding=True,\n        keep_initializers_as_inputs=False,\n        onnx_opset_version: int = 12,\n        try_script: bool = False,\n        set_eval: bool = True,\n        check_trace: bool = True,\n        use_dynamic_axes: bool = True,\n        dynamic_axes=None,\n        check_tolerance=0.01,\n        forward_method=None,\n    ):\n\n        qual_name = self.__module__ + '.' + self.__class__.__qualname__\n        output_descr = qual_name + ' exported to ONNX'\n        exported = ([output], [output_descr])\n\n        try:\n            # Disable typechecks\n            typecheck.set_typecheck_enabled(enabled=False)\n\n            # Allow user to completely override forward method to export\n            if forward_method is None and hasattr(type(self), \"forward_for_export\"):\n                forward_method = type(self).forward_for_export\n\n            if forward_method:\n                old_forward_method = type(self).forward\n                type(self).forward = forward_method\n\n            # Set module to eval mode\n            if set_eval:\n                self.eval()\n\n            format = self.get_format(output)\n            self._prepare_for_export()\n\n            with torch.jit.optimized_execution(True):\n                jitted_model = None\n                if try_script:\n                    try:\n                        jitted_model = torch.jit.script(self)\n                    except Exception as e:\n                        print(\"jit.script() failed!\", e)\n\n            if input_example is None:\n                input_example = self.input_module.input_example()\n\n            with torch.jit.optimized_execution(True):\n                if format == ExportFormat.TORCHSCRIPT:\n                    if isinstance(input_example, Dict):\n          
              input_example = tuple(input_example.values())\n\n                    if jitted_model is None:\n                        jitted_model = torch.jit.trace(\n                            self,\n                            input_example,\n                            strict=False,\n                            optimize=True,\n                            check_trace=check_trace,\n                            check_tolerance=check_tolerance,\n                        )\n                    jitted_model.save(output)\n                    assert os.path.exists(output)\n\n                elif format == ExportFormat.ONNX:\n                    if jitted_model is None:\n                        jitted_model = self\n                    if output_example is None:\n                        if isinstance(input_example, tuple):\n                            output_example = self.forward(*input_example)\n                        else:\n                            output_example = self.forward(input_example)\n\n                    input_names = self.input_module.get_input_names(input_example)\n                    output_names = self.output_module.get_output_names(output_example)\n\n                    # dynamic axis is a mapping from input/output_name => list of \"dynamic\" indices\n                    if dynamic_axes is None and use_dynamic_axes:\n                        dynamic_axes = self.input_module.get_input_dynamic_axes(input_names)\n                        dynamic_axes = {**dynamic_axes, **self.output_module.get_output_dynamic_axes(output_names)}\n\n                    if isinstance(input_example, tuple):\n                        logging.info(f'ONNX input_example {len(input_example)}')\n                        \n                        for idx, x in enumerate(input_example):\n                            logging.info(f'  - {idx}  {x.shape}')\n                            \n                        \"\"\"\n                        if len(input_names) < len(input_example):\n       
                     logging.warning(f'removing extra input_examples to match number of input_names')\n                            input_example = tuple([input_example[x] for x in range(len(input_names))])\n                            logging.warning(f'new number of input_examples:  {len(input_example)}')\n                        \"\"\"\n                        \n                    logging.info(f'ONNX class_name    {type(self).__name__}')\n                    logging.info(f'ONNX input_names   {input_names}')\n                    logging.info(f'ONNX output_names  {output_names}')\n                    logging.info(f'ONNX dynamic_axes  {dynamic_axes}')\n\n                    if isinstance(input_example, Dict):\n                        input_example = tuple(input_example.values())\n\n                    torch.onnx.export(\n                        jitted_model,\n                        input_example,\n                        output,\n                        input_names=input_names,\n                        output_names=output_names,\n                        verbose=verbose,\n                        export_params=export_params,\n                        do_constant_folding=do_constant_folding,\n                        keep_initializers_as_inputs=keep_initializers_as_inputs,\n                        dynamic_axes=dynamic_axes,\n                        opset_version=onnx_opset_version,\n                        example_outputs=output_example,\n                    )\n\n                    # Verify the model can be read, and is valid\n                    onnx_model = onnx.load(output)\n                    onnx.checker.check_model(onnx_model, full_check=True)\n\n                    if do_constant_folding:\n                        if not ONNX_GRAPHSURGEON_AVAILABLE:\n                            logging.info(\n                                f\"onnx-graphsurgeon module is not instlled.\"\n                                \"That may result in suboptimal optimization of exported 
ONNX graph (including unneeded DOUBLE initializers).\"\n                                \"Please follow the instructions available at:\"\n                                \"https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon\"\n                                \"to install onnx-graphsurgeon from source to improve exported graph.\"\n                            )\n                        else:\n                            # This pass is to remove/recast certain constants that are generated as 'double'\n                            # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)\n                            # Can probably be removed once TRT has automatic downcast for double.\n                            # However, it may still be useful even then as it seems to always make the graph shorter.\n                            graph = gs.import_onnx(onnx_model)\n                            onnx_model = gs.export_onnx(graph.fold_constants().cleanup())\n                            onnx.checker.check_model(onnx_model, full_check=True)\n                            onnx.save(onnx_model, output)\n                else:\n                    raise ValueError(f'Encountered unknown export format {format}.')\n        finally:\n            typecheck.set_typecheck_enabled(enabled=True)\n            if forward_method:\n                type(self).forward = old_forward_method\n        return exported\n\n    @property\n    def disabled_deployment_input_names(self):\n        \"\"\"Implement this method to return a set of input names disabled for export\"\"\"\n        return set()\n\n    @property\n    def disabled_deployment_output_names(self):\n        \"\"\"Implement this method to return a set of output names disabled for export\"\"\"\n        return set()\n\n    @property\n    def supported_export_formats(self):\n        \"\"\"Implement this method to return a set of export formats supported. 
Default is all types.\"\"\"\n        return set([ExportFormat.ONNX, ExportFormat.TORCHSCRIPT])\n\n    @staticmethod\n    def _extract_dynamic_axes(name: str, ntype: NeuralType):\n        \"\"\"\n        Implement this method to provide dynamic axes id for ONNX export.\n        By default, this method will extract BATCH and TIME dimension ids from each provided input/output name argument.\n\n        For example, if module/model accepts argument named \"input_signal\" with type corresponding to [Batch, Time, Dim]\n        shape, then the returned result should contain \"input_signal\" -> [0, 1] because Batch and Time are dynamic axes\n        as they can change from call to call during inference.\n\n        Args:\n            name: Name of input or output parameter\n            ntype: Corresponding Neural Type\n\n        Returns:\n\n        \"\"\"\n        dynamic_axes = defaultdict(list)\n        if ntype.axes:\n            for ind, axis in enumerate(ntype.axes):\n                if axis.kind in [AxisKind.Batch, AxisKind.Time, AxisKind.Width, AxisKind.Height]:\n                    dynamic_axes[name].append(ind)\n        return dynamic_axes\n\n    def _prepare_for_export(self, replace_1D_2D=False):\n        \"\"\"\n        Override this method to prepare module for export. This is in-place operation.\n        Base version does common necessary module replacements (Apex etc)\n        \"\"\"\n        replace_for_export(self, replace_1D_2D)\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/__init__.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom nemo.collections.nlp.modules.common.huggingface.albert import AlbertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.bert import BertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.distilbert import DistilBertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.huggingface_utils import (\n    get_huggingface_lm_model,\n    get_huggingface_pretrained_lm_models_list,\n)\nfrom nemo.collections.nlp.modules.common.huggingface.roberta import RobertaEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.mobilebert import MobileBertEncoder\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/distilbert.diff",
    "content": "17a18\n> from typing import Dict, Optional\n19a21\n> from nemo.core.neural_types import ChannelType, MaskType, NeuralType\n29a32,53\n>     @property\n>     def input_types(self) -> Optional[Dict[str, NeuralType]]:\n>         \"\"\"\n>         These are ordered incorrectly in bert_module.py WRT to QAModel.forward()\n>         DistilBert doesn't use token_type_ids, but the QAModel still needs them during export.\n>         By re-ordring them, the correct input_names are used during export of the ONNX model.\n>         \"\"\"\n>         return {\n>             \"input_ids\": NeuralType(('B', 'T'), ChannelType()),\n>             \"token_type_ids\": NeuralType(('B', 'T'), ChannelType(), optional=True),\n>             \"attention_mask\": NeuralType(('B', 'T'), MaskType(), optional=True)\n>         }\n> \n>     '''\n>     # note:  disabling the token_type_ids here still leads to incorrect names, because QAModel.forward()\n>     #        still needs the token_type_ids to run the trace, and hence the input_example is still larger\n>     @property\n>     def disabled_deployment_input_names(self):\n>         \"\"\"Implement this method to return a set of input names disabled for export\"\"\"\n>         return ['token_type_ids']\n>     '''\n>     \n34a59\n>         \n\\ No newline at end of file\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/distilbert.original.py",
    "content": "# Copyright 2020 The Google AI Language Team Authors and\n# The HuggingFace Inc. team.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import DistilBertModel\n\nfrom nemo.collections.nlp.modules.common.bert_module import BertModule\nfrom nemo.core.classes import typecheck\n\n__all__ = ['DistilBertEncoder']\n\n\nclass DistilBertEncoder(DistilBertModel, BertModule):\n    \"\"\"\n    Wraps around the Huggingface transformers implementation repository for easy use within NeMo.\n    \"\"\"\n\n    @typecheck()\n    def forward(self, input_ids, attention_mask, token_type_ids=None):\n        # distilBert does not use token_type_ids as the most of the other Bert models\n        res = super().forward(input_ids=input_ids, attention_mask=attention_mask)[0]\n        return res\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/distilbert.py",
    "content": "# Copyright 2020 The Google AI Language Team Authors and\n# The HuggingFace Inc. team.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import DistilBertModel\nfrom typing import Dict, Optional\n\nfrom nemo.collections.nlp.modules.common.bert_module import BertModule\nfrom nemo.core.neural_types import ChannelType, MaskType, NeuralType\nfrom nemo.core.classes import typecheck\n\n__all__ = ['DistilBertEncoder']\n\n\nclass DistilBertEncoder(DistilBertModel, BertModule):\n    \"\"\"\n    Wraps around the Huggingface transformers implementation repository for easy use within NeMo.\n    \"\"\"\n\n    @property\n    def input_types(self) -> Optional[Dict[str, NeuralType]]:\n        \"\"\"\n        These are ordered incorrectly in bert_module.py WRT to QAModel.forward()\n        DistilBert doesn't use token_type_ids, but the QAModel still needs them during export.\n        By re-ordring them, the correct input_names are used during export of the ONNX model.\n        \"\"\"\n        return {\n            \"input_ids\": NeuralType(('B', 'T'), ChannelType()),\n            \"token_type_ids\": NeuralType(('B', 'T'), ChannelType(), optional=True),\n            \"attention_mask\": NeuralType(('B', 'T'), MaskType(), optional=True)\n        }\n\n    '''\n    # note:  disabling the token_type_ids here still leads to incorrect names, because QAModel.forward()\n    #        still 
needs the token_type_ids to run the trace, and hence the input_example is still larger\n    @property\n    def disabled_deployment_input_names(self):\n        \"\"\"Implement this method to return a set of input names disabled for export\"\"\"\n        return ['token_type_ids']\n    '''\n    \n    @typecheck()\n    def forward(self, input_ids, attention_mask, token_type_ids=None):\n        # distilBert does not use token_type_ids as the most of the other Bert models\n        res = super().forward(input_ids=input_ids, attention_mask=attention_mask)[0]\n        return res\n        "
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/huggingface_utils.py",
    "content": "# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom typing import List, Optional\n\nfrom transformers import (\n    ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,\n    BERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,\n    MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    AlbertConfig,\n    AutoModel,\n    BertConfig,\n    DistilBertConfig,\n    RobertaConfig,\n    MobileBertConfig,\n)\n\nfrom nemo.collections.nlp.modules.common.huggingface.albert import AlbertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.bert import BertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.distilbert import DistilBertEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.roberta import RobertaEncoder\nfrom nemo.collections.nlp.modules.common.huggingface.mobilebert import MobileBertEncoder\nfrom nemo.utils import logging\n\n__all__ = [\"get_huggingface_lm_model\", \"get_huggingface_pretrained_lm_models_list\"]\n\n\nHUGGINGFACE_MODELS = {\n    \"BertModel\": {\n        \"default\": \"bert-base-uncased\",\n        \"class\": BertEncoder,\n        \"config\": BertConfig,\n        \"pretrained_model_list\": BERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    },\n    \"DistilBertModel\": {\n        \"default\": \"distilbert-base-uncased\",\n        \"class\": 
DistilBertEncoder,\n        \"config\": DistilBertConfig,\n        \"pretrained_model_list\": DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    },\n    \"RobertaModel\": {\n        \"default\": \"roberta-base\",\n        \"class\": RobertaEncoder,\n        \"config\": RobertaConfig,\n        \"pretrained_model_list\": ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,\n    },\n    \"AlbertModel\": {\n        \"default\": \"albert-base-v2\",\n        \"class\": AlbertEncoder,\n        \"config\": AlbertConfig,\n        \"pretrained_model_list\": ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    },\n    \"MobileBertModel\": {\n        \"default\": \"google/mobilebert-uncased\",\n        \"class\": MobileBertEncoder,\n        \"config\": MobileBertConfig,\n        \"pretrained_model_list\": MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,\n    },\n}\n\n\ndef get_huggingface_lm_model(\n    pretrained_model_name: str, config_dict: Optional[dict] = None, config_file: Optional[str] = None,\n):\n    \"\"\"\n    Returns lm model instantiated with Huggingface\n\n    Args:\n        pretrained_mode_name: specify this to instantiate pretrained model from Huggingface,\n            e.g. bert-base-cased. For entire list, see get_huggingface_pretrained_lm_models_list().\n        config_dict: model configuration dictionary used to instantiate Huggingface model from scratch\n        config_file: path to model configuration file used to instantiate Huggingface model from scratch\n\n    Returns:\n        BertModule\n    \"\"\"\n\n    try:\n        automodel = AutoModel.from_pretrained(pretrained_model_name)\n    except Exception as e:\n        raise ValueError(f\"{pretrained_model_name} is not supported by HuggingFace. 
{e}\")\n\n    model_type = type(automodel).__name__\n    if model_type in HUGGINGFACE_MODELS:\n        model_class = HUGGINGFACE_MODELS[model_type][\"class\"]\n        if config_file:\n            if not os.path.exists(config_file):\n                logging.warning(\n                    f\"Config file was not found at {config_file}. Will attempt to use config_dict or pretrained_model_name.\"\n                )\n            else:\n                config_class = HUGGINGFACE_MODELS[model_type][\"config\"]\n                return model_class(config_class.from_json_file(config_file))\n        if config_dict:\n            config_class = HUGGINGFACE_MODELS[model_type][\"config\"]\n            return model_class(config=config_class(**config_dict))\n        else:\n            return model_class.from_pretrained(pretrained_model_name)\n    else:\n        raise ValueError(f\"Use HuffingFace API directly in NeMo for {pretrained_model_name}\")\n\n\ndef get_huggingface_pretrained_lm_models_list(include_external: bool = False,) -> List[str]:\n    \"\"\"\n    Returns the list of pretrained HuggingFace language models\n    \n    Args:\n        include_external if true includes all HuggingFace model names, not only those supported language models in NeMo.\n    \n    Returns the list of HuggingFace models\n    \"\"\"\n\n    huggingface_models = []\n    if include_external:\n        huggingface_models = list(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())\n    else:\n        for model in HUGGINGFACE_MODELS:\n            model_names = HUGGINGFACE_MODELS[model][\"pretrained_model_list\"]\n            huggingface_models.extend(model_names)\n    return huggingface_models\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/location.txt",
    "content": "nemo/collections/nlp/modules/common/huggingface\r\n\r\nMain branch. Commit 21a17b267fac68d4cdd20f3969a580a0a40dbdb4"
  },
  {
    "path": "patches/nemo/1.0.0rc1/nlp/mobilebert.py",
    "content": "# Copyright 2018 The Google AI Language Team Authors and\n# The HuggingFace Inc. team.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers import MobileBertModel\n\nfrom nemo.collections.nlp.modules.common.bert_module import BertModule\nfrom nemo.core.classes import typecheck\n\n__all__ = ['MobileBertEncoder']\n\n\nclass MobileBertEncoder(MobileBertModel, BertModule):\n    \"\"\"\n    Wraps around the Huggingface transformers implementation repository for easy use within NeMo.\n    \"\"\"\n\n    @typecheck()\n    def forward(self, input_ids, attention_mask, token_type_ids):\n        res = super().forward(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]\n        return res\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/setup.original.py",
    "content": "# ! /usr/bin/python\n# -*- coding: utf-8 -*-\n\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Setup for pip package.\"\"\"\n\nimport codecs\nimport os\nimport subprocess\nimport sys\nfrom distutils import cmd as distutils_cmd\nfrom distutils import log as distutils_log\nfrom itertools import chain\n\nimport setuptools\n\n\ndef is_build_action():\n    if len(sys.argv) <= 1:\n        return False\n\n    BUILD_TOKENS = [\"egg_info\", \"dist\", \"bdist\", \"sdist\", \"install\", \"build\", \"develop\", \"style\", \"clean\"]\n\n    if any([sys.argv[1].startswith(x) for x in BUILD_TOKENS]):\n        return True\n    else:\n        return False\n\n\nif is_build_action():\n    os.environ['NEMO_PACKAGE_BUILDING'] = 'True'\n\nfrom nemo.package_info import (\n    __contact_emails__,\n    __contact_names__,\n    __description__,\n    __download_url__,\n    __homepage__,\n    __keywords__,\n    __license__,\n    __package_name__,\n    __repository_url__,\n    __version__,\n)\n\nif os.path.exists('nemo/README.md'):\n    with open(\"nemo/README.md\", \"r\") as fh:\n        long_description = fh.read()\n    long_description_content_type = \"text/markdown\"\n\nelif os.path.exists('README.rst'):\n    # codec is used for consistent encoding\n    long_description = codecs.open(\n        os.path.join(os.path.abspath(os.path.dirname(__file__)), 'README.rst'), 'r', 'utf-8',\n    ).read()\n    
long_description_content_type = \"text/x-rst\"\n\nelse:\n    long_description = 'See ' + __homepage__\n\n\n###############################################################################\n#                             Dependency Loading                              #\n# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #\n\n\ndef req_file(filename, folder=\"requirements\"):\n    with open(os.path.join(folder, filename)) as f:\n        content = f.readlines()\n    # you may also want to remove whitespace characters\n    # Example: `\\n` at the end of each line\n    return [x.strip() for x in content]\n\n\ninstall_requires = req_file(\"requirements.txt\")\n\nextras_require = {\n    # User packages\n    'test': req_file(\"requirements_test.txt\"),\n    # Collections Packages\n    'asr': req_file(\"requirements_asr.txt\"),\n    'cv': req_file(\"requirements_cv.txt\"),\n    'nlp': req_file(\"requirements_nlp.txt\"),\n    'tts': req_file(\"requirements_tts.txt\"),\n}\n\nextras_require['all'] = list(chain(extras_require.values()))\n\n# TTS depends on ASR\nextras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr']]))\n\ntests_requirements = extras_require[\"test\"]\n\n########################## VERSION MISMATCH PATCH #############################\n# REMOVE AFTER 21.03 Container is released !\n\ntry:\n    import torch\n\n    version = torch.__version__\n    SUPPORTED_TORCH_VERSION = f\"torch=={version}\"\n\n    if 'a' in version or 'b' in version:\n        # It is githash release, force to supported Pytorch Lightning branch\n        SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n    else:\n        # Downgrade torch, pytorch-lightning\n        SUPPORTED_TORCH_VERSION = \"torch<=1.7.1\"\n        SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n\nexcept (ImportError, ModuleNotFoundError):\n    # Since no torch is installed, pip install torch will install latest torch and latest pytorch lightning\n    
SUPPORTED_TORCH_VERSION = \"torch<=1.7.1\"\n    SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n\ninstall_requires_buffer = []\nfor ix, line in enumerate(install_requires):\n    if 'lightning' in line:\n        install_requires_buffer.append(SUPPORTED_PYTORCH_LIGHTNING)\n    elif 'torch' in line:\n        install_requires_buffer.append(SUPPORTED_TORCH_VERSION)\n\n        # Pytorch 1.7.1 must use torchtext==0.8.0, torchaudio==0.7.2 and torchvision==0.8.2\n        if SUPPORTED_TORCH_VERSION == \"torch<=1.7.1\":\n            install_requires_buffer.append(\"torchvision==0.8.2\")\n            install_requires_buffer.append(\"torchaudio==0.7.2\")\n            install_requires_buffer.append(\"torchtext==0.8.0\")\n\n    else:\n        install_requires_buffer.append(line)\n\n# override install requires\ninstall_requires = install_requires_buffer\n\n###############################################################################\n#                            Code style checkers                              #\n# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #\n\n\nclass StyleCommand(distutils_cmd.Command):\n    __LINE_WIDTH = 119\n    __ISORT_BASE = (\n        'isort '\n        # These two lines makes isort compatible with black.\n        '--multi-line=3 --trailing-comma --force-grid-wrap=0 '\n        f'--use-parentheses --line-width={__LINE_WIDTH} -rc -ws'\n    )\n    __BLACK_BASE = f'black --skip-string-normalization --line-length={__LINE_WIDTH}'\n    description = 'Checks overall project code style.'\n    user_options = [\n        ('scope=', None, 'Folder of file to operate within.'),\n        ('fix', None, 'True if tries to fix issues in-place.'),\n    ]\n\n    def __call_checker(self, base_command, scope, check):\n        command = list(base_command)\n\n        command.append(scope)\n\n        if check:\n            command.extend(['--check', '--diff'])\n\n        self.announce(\n            msg='Running command: %s' % str(' 
'.join(command)), level=distutils_log.INFO,\n        )\n\n        return_code = subprocess.call(command)\n\n        return return_code\n\n    def _isort(self, scope, check):\n        return self.__call_checker(base_command=self.__ISORT_BASE.split(), scope=scope, check=check,)\n\n    def _black(self, scope, check):\n        return self.__call_checker(base_command=self.__BLACK_BASE.split(), scope=scope, check=check,)\n\n    def _pass(self):\n        self.announce(msg='\\033[32mPASS\\x1b[0m', level=distutils_log.INFO)\n\n    def _fail(self):\n        self.announce(msg='\\033[31mFAIL\\x1b[0m', level=distutils_log.INFO)\n\n    # noinspection PyAttributeOutsideInit\n    def initialize_options(self):\n        self.scope = '.'\n        self.fix = ''\n\n    def run(self):\n        scope, check = self.scope, not self.fix\n        isort_return = self._isort(scope=scope, check=check)\n        black_return = self._black(scope=scope, check=check)\n\n        if isort_return == 0 and black_return == 0:\n            self._pass()\n        else:\n            self._fail()\n            exit(isort_return if isort_return != 0 else black_return)\n\n    def finalize_options(self):\n        pass\n\n\n###############################################################################\n\nsetuptools.setup(\n    name=__package_name__,\n    # Versions should comply with PEP440.  
For a discussion on single-sourcing\n    # the version across setup.py and the project code, see\n    # https://packaging.python.org/en/latest/single_source_version.html\n    version=__version__,\n    description=__description__,\n    long_description=long_description,\n    long_description_content_type=long_description_content_type,\n    # The project's main homepage.\n    url=__repository_url__,\n    download_url=__download_url__,\n    # Author details\n    author=__contact_names__,\n    author_email=__contact_emails__,\n    # maintainer Details\n    maintainer=__contact_names__,\n    maintainer_email=__contact_emails__,\n    # The licence under which the project is released\n    license=__license__,\n    classifiers=[\n        # How mature is this project? Common values are\n        #  1 - Planning\n        #  2 - Pre-Alpha\n        #  3 - Alpha\n        #  4 - Beta\n        #  5 - Production/Stable\n        #  6 - Mature\n        #  7 - Inactive\n        'Development Status :: 4 - Beta',\n        # Indicate who your project is intended for\n        'Intended Audience :: Developers',\n        'Intended Audience :: Science/Research',\n        'Intended Audience :: Information Technology',\n        # Indicate what your project relates to\n        'Topic :: Scientific/Engineering',\n        'Topic :: Scientific/Engineering :: Mathematics',\n        'Topic :: Scientific/Engineering :: Image Recognition',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n        'Topic :: Software Development :: Libraries',\n        'Topic :: Software Development :: Libraries :: Python Modules',\n        'Topic :: Utilities',\n        # Pick your license as you wish (should match \"license\" above)\n        'License :: OSI Approved :: Apache Software License',\n        # Supported python versions\n        'Programming Language :: Python :: 3',\n        'Programming Language :: Python :: 3.5',\n        'Programming Language :: Python :: 3.6',\n        
'Programming Language :: Python :: 3.7',\n        'Programming Language :: Python :: 3.8',\n        # Additional Setting\n        'Environment :: Console',\n        'Natural Language :: English',\n        'Operating System :: OS Independent',\n    ],\n    packages=setuptools.find_packages(),\n    install_requires=install_requires,\n    setup_requires=['pytest-runner'],\n    tests_require=tests_requirements,\n    # List additional groups of dependencies here (e.g. development\n    # dependencies). You can install these using the following syntax,\n    # $ pip install -e \".[all]\"\n    # $ pip install nemo_toolkit[all]\n    extras_require=extras_require,\n    # Add in any packaged data.\n    include_package_data=True,\n    zip_safe=False,\n    # PyPI package information.\n    keywords=__keywords__,\n    # Custom commands.\n    cmdclass={'style': StyleCommand},\n)\n"
  },
  {
    "path": "patches/nemo/1.0.0rc1/setup.py",
    "content": "# ! /usr/bin/python\n# -*- coding: utf-8 -*-\n\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Setup for pip package.\"\"\"\n\nimport codecs\nimport os\nimport subprocess\nimport sys\nfrom distutils import cmd as distutils_cmd\nfrom distutils import log as distutils_log\nfrom itertools import chain\n\nimport setuptools\n\n\ndef is_build_action():\n    if len(sys.argv) <= 1:\n        return False\n\n    BUILD_TOKENS = [\"egg_info\", \"dist\", \"bdist\", \"sdist\", \"install\", \"build\", \"develop\", \"style\", \"clean\"]\n\n    if any([sys.argv[1].startswith(x) for x in BUILD_TOKENS]):\n        return True\n    else:\n        return False\n\n\nif is_build_action():\n    os.environ['NEMO_PACKAGE_BUILDING'] = 'True'\n\nfrom nemo.package_info import (\n    __contact_emails__,\n    __contact_names__,\n    __description__,\n    __download_url__,\n    __homepage__,\n    __keywords__,\n    __license__,\n    __package_name__,\n    __repository_url__,\n    __version__,\n)\n\nif os.path.exists('nemo/README.md'):\n    with open(\"nemo/README.md\", \"r\") as fh:\n        long_description = fh.read()\n    long_description_content_type = \"text/markdown\"\n\nelif os.path.exists('README.rst'):\n    # codec is used for consistent encoding\n    long_description = codecs.open(\n        os.path.join(os.path.abspath(os.path.dirname(__file__)), 'README.rst'), 'r', 'utf-8',\n    ).read()\n    
long_description_content_type = \"text/x-rst\"\n\nelse:\n    long_description = 'See ' + __homepage__\n\n\n###############################################################################\n#                             Dependency Loading                              #\n# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #\n\n\ndef req_file(filename, folder=\"requirements\"):\n    with open(os.path.join(folder, filename)) as f:\n        content = f.readlines()\n    # you may also want to remove whitespace characters\n    # Example: `\\n` at the end of each line\n    return [x.strip() for x in content]\n\n\ninstall_requires = req_file(\"requirements.txt\")\n\nextras_require = {\n    # User packages\n    'test': req_file(\"requirements_test.txt\"),\n    # Collections Packages\n    'asr': req_file(\"requirements_asr.txt\"),\n    'cv': req_file(\"requirements_cv.txt\"),\n    'nlp': req_file(\"requirements_nlp.txt\"),\n    'tts': req_file(\"requirements_tts.txt\"),\n}\n\nextras_require['all'] = list(chain(extras_require.values()))\n\n# TTS depends on ASR\nextras_require['tts'] = list(chain([extras_require['tts'], extras_require['asr']]))\n\ntests_requirements = extras_require[\"test\"]\n\n########################## VERSION MISMATCH PATCH #############################\n# REMOVE AFTER 21.03 Container is released !\n\ntry:\n    import torch\n\n    version = torch.__version__\n    SUPPORTED_TORCH_VERSION = f\"torch=={version}\"\n\n    if 'a' in version or 'b' in version:\n        # It is githash release, force to supported Pytorch Lightning branch\n        SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n    else:\n        # Downgrade torch, pytorch-lightning\n        SUPPORTED_TORCH_VERSION = \"torch<=1.7.1\"\n        SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n\nexcept (ImportError, ModuleNotFoundError):\n    # Since no torch is installed, pip install torch will install latest torch and latest pytorch lightning\n    
SUPPORTED_TORCH_VERSION = \"torch<=1.7.1\"\n    SUPPORTED_PYTORCH_LIGHTNING = \"pytorch-lightning==1.1.5\"\n\ninstall_requires_buffer = []\nfor ix, line in enumerate(install_requires):\n    if 'lightning' in line:\n        install_requires_buffer.append(SUPPORTED_PYTORCH_LIGHTNING)\n    elif 'torch' in line:\n        install_requires_buffer.append(SUPPORTED_TORCH_VERSION)\n\n        # Pytorch 1.7.1 must use torchtext==0.8.0, torchaudio==0.7.2 and torchvision==0.8.2\n        if SUPPORTED_TORCH_VERSION == \"torch<=1.7.1\":\n            install_requires_buffer.append(\"torchvision\") #\"torchvision==0.8.2\") # when we built from src in the container, it has a slightly different versions of these torch libraries\n            install_requires_buffer.append(\"torchaudio\") #\"torchaudio==0.7.2\")\n            install_requires_buffer.append(\"torchtext\") #\"torchtext==0.8.0\") \n\n    else:\n        install_requires_buffer.append(line)\n\n# override install requires\ninstall_requires = install_requires_buffer\n\n###############################################################################\n#                            Code style checkers                              #\n# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #\n\n\nclass StyleCommand(distutils_cmd.Command):\n    __LINE_WIDTH = 119\n    __ISORT_BASE = (\n        'isort '\n        # These two lines makes isort compatible with black.\n        '--multi-line=3 --trailing-comma --force-grid-wrap=0 '\n        f'--use-parentheses --line-width={__LINE_WIDTH} -rc -ws'\n    )\n    __BLACK_BASE = f'black --skip-string-normalization --line-length={__LINE_WIDTH}'\n    description = 'Checks overall project code style.'\n    user_options = [\n        ('scope=', None, 'Folder of file to operate within.'),\n        ('fix', None, 'True if tries to fix issues in-place.'),\n    ]\n\n    def __call_checker(self, base_command, scope, check):\n        command = list(base_command)\n\n        
command.append(scope)\n\n        if check:\n            command.extend(['--check', '--diff'])\n\n        self.announce(\n            msg='Running command: %s' % str(' '.join(command)), level=distutils_log.INFO,\n        )\n\n        return_code = subprocess.call(command)\n\n        return return_code\n\n    def _isort(self, scope, check):\n        return self.__call_checker(base_command=self.__ISORT_BASE.split(), scope=scope, check=check,)\n\n    def _black(self, scope, check):\n        return self.__call_checker(base_command=self.__BLACK_BASE.split(), scope=scope, check=check,)\n\n    def _pass(self):\n        self.announce(msg='\\033[32mPASS\\x1b[0m', level=distutils_log.INFO)\n\n    def _fail(self):\n        self.announce(msg='\\033[31mFAIL\\x1b[0m', level=distutils_log.INFO)\n\n    # noinspection PyAttributeOutsideInit\n    def initialize_options(self):\n        self.scope = '.'\n        self.fix = ''\n\n    def run(self):\n        scope, check = self.scope, not self.fix\n        isort_return = self._isort(scope=scope, check=check)\n        black_return = self._black(scope=scope, check=check)\n\n        if isort_return == 0 and black_return == 0:\n            self._pass()\n        else:\n            self._fail()\n            exit(isort_return if isort_return != 0 else black_return)\n\n    def finalize_options(self):\n        pass\n\n\n###############################################################################\n\nsetuptools.setup(\n    name=__package_name__,\n    # Versions should comply with PEP440.  
For a discussion on single-sourcing\n    # the version across setup.py and the project code, see\n    # https://packaging.python.org/en/latest/single_source_version.html\n    version=__version__,\n    description=__description__,\n    long_description=long_description,\n    long_description_content_type=long_description_content_type,\n    # The project's main homepage.\n    url=__repository_url__,\n    download_url=__download_url__,\n    # Author details\n    author=__contact_names__,\n    author_email=__contact_emails__,\n    # maintainer Details\n    maintainer=__contact_names__,\n    maintainer_email=__contact_emails__,\n    # The licence under which the project is released\n    license=__license__,\n    classifiers=[\n        # How mature is this project? Common values are\n        #  1 - Planning\n        #  2 - Pre-Alpha\n        #  3 - Alpha\n        #  4 - Beta\n        #  5 - Production/Stable\n        #  6 - Mature\n        #  7 - Inactive\n        'Development Status :: 4 - Beta',\n        # Indicate who your project is intended for\n        'Intended Audience :: Developers',\n        'Intended Audience :: Science/Research',\n        'Intended Audience :: Information Technology',\n        # Indicate what your project relates to\n        'Topic :: Scientific/Engineering',\n        'Topic :: Scientific/Engineering :: Mathematics',\n        'Topic :: Scientific/Engineering :: Image Recognition',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n        'Topic :: Software Development :: Libraries',\n        'Topic :: Software Development :: Libraries :: Python Modules',\n        'Topic :: Utilities',\n        # Pick your license as you wish (should match \"license\" above)\n        'License :: OSI Approved :: Apache Software License',\n        # Supported python versions\n        'Programming Language :: Python :: 3',\n        'Programming Language :: Python :: 3.5',\n        'Programming Language :: Python :: 3.6',\n        
'Programming Language :: Python :: 3.7',\n        'Programming Language :: Python :: 3.8',\n        # Additional Setting\n        'Environment :: Console',\n        'Natural Language :: English',\n        'Operating System :: OS Independent',\n    ],\n    packages=setuptools.find_packages(),\n    install_requires=install_requires,\n    setup_requires=['pytest-runner'],\n    tests_require=tests_requirements,\n    # List additional groups of dependencies here (e.g. development\n    # dependencies). You can install these using the following syntax,\n    # $ pip install -e \".[all]\"\n    # $ pip install nemo_toolkit[all]\n    extras_require=extras_require,\n    # Add in any packaged data.\n    include_package_data=True,\n    zip_safe=False,\n    # PyPI package information.\n    keywords=__keywords__,\n    # Custom commands.\n    cmdclass={'style': StyleCommand},\n)\n"
  },
  {
    "path": "patches/nemo/1.6.2/requirements.original.txt",
    "content": "numpy>=1.21\nonnx>=1.7.0\npython-dateutil\ntorch\nwrapt\nruamel.yaml\nscikit-learn\nsentencepiece<1.0.0\ntqdm>=4.41.0\nnumba\nwget\nfrozendict\nunidecode\n"
  },
  {
    "path": "patches/nemo/1.6.2/requirements.txt",
    "content": "numpy\nonnx>=1.7.0\npython-dateutil\ntorch\nwrapt\nruamel.yaml\nscikit-learn\nsentencepiece<1.0.0\ntqdm>=4.41.0\nnumba\nwget\nfrozendict\nunidecode\n"
  },
  {
    "path": "patches/nemo/1.6.2/requirements_nlp.original.txt",
    "content": "boto3\nh5py\nmatplotlib>=3.3.2\nsentencepiece\nyoutokentome>=1.0.5\nnumpy\nrapidfuzz\ngdown\ninflect\nsacrebleu[ja]\nsacremoses>=0.0.43\nnltk>=3.6.5\nfasttext\nopencc\npangu\njieba\nftfy\n"
  },
  {
    "path": "patches/nemo/1.6.2/requirements_nlp.txt",
    "content": "boto3\nh5py\nmatplotlib\nsentencepiece\nyoutokentome>=1.0.5\nnumpy\ngdown\ninflect\nsacremoses>=0.0.43\nnltk>=3.6.5\nfasttext\nopencc\npangu\njieba\nftfy\n"
  },
  {
    "path": "patches/pytorch/1.6.0/functional.diff",
    "content": "2a3,5\n> import librosa  # STFT patch for aarch64\n> import numpy as np\n> \n465c468,478\n<     return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)\n---\n>         \n>     # STFT patch for aarch64\n>     # https://stackoverflow.com/a/66872148\n>     librosa_stft = librosa.stft(input.cpu().detach().numpy().reshape(-1), n_fft, hop_length, win_length, window=\"hann\", center=center, pad_mode=pad_mode)\n>     librosa_stft = np.array([[a.real, a.imag] for a in librosa_stft])\n>     librosa_stft = np.transpose(librosa_stft, axes=[0, 2, 1])\n>     librosa_stft = np.expand_dims(librosa_stft, 0)\n>     librosa_stft = torch.from_numpy(librosa_stft)\n>     return librosa_stft\n>     #return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n>     #                normalized, onesided, return_complex)\n"
  },
  {
    "path": "patches/pytorch/1.6.0/functional.original.py",
    "content": "from typing import Tuple, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom ._lowrank import svd_lowrank, pca_lowrank\nfrom ._overrides import has_torch_function, handle_torch_function\nfrom ._jit_internal import boolean_dispatch, List\nfrom ._jit_internal import _overload as overload\n\nTensor = torch.Tensor\nfrom torch import _VF\n\n__all__ = [\n    'align_tensors',\n    'broadcast_tensors',\n    'cartesian_prod',\n    'block_diag',\n    'cdist',\n    'chain_matmul',\n    'einsum',\n    'istft',\n    'lu',\n    'lu_unpack',\n    'norm',\n    'meshgrid',\n    'pca_lowrank',\n    'split',\n    'stft',\n    'svd_lowrank',\n    'tensordot',\n    'unique',\n    'unique_consecutive',\n]\n\n\ndef broadcast_tensors(*tensors):\n    r\"\"\"broadcast_tensors(*tensors) -> List of Tensors\n\n    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.\n\n    Args:\n        *tensors: any number of tensors of the same type\n\n    .. warning::\n\n        More than one element of a broadcasted tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. If you need to write\n        to the tensors, please clone them first.\n\n    Example::\n\n        >>> x = torch.arange(3).view(1, 3)\n        >>> y = torch.arange(2).view(2, 1)\n        >>> a, b = torch.broadcast_tensors(x, y)\n        >>> a.size()\n        torch.Size([2, 3])\n        >>> a\n        tensor([[0, 1, 2],\n                [0, 1, 2]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(broadcast_tensors, tensors, *tensors)\n    return _VF.broadcast_tensors(tensors)\n\n\ndef split(tensor, split_size_or_sections, dim=0):\n    r\"\"\"Splits the tensor into chunks. 
Each chunk is a view of the original tensor.\n\n    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will\n    be split into equally sized chunks (if possible). Last chunk will be smaller if\n    the tensor size along the given dimension :attr:`dim` is not divisible by\n    :attr:`split_size`.\n\n    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split\n    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according\n    to :attr:`split_size_or_sections`.\n\n    Arguments:\n        tensor (Tensor): tensor to split.\n        split_size_or_sections (int) or (list(int)): size of a single chunk or\n            list of sizes for each chunk\n        dim (int): dimension along which to split the tensor.\n\n    Example::\n        >>> a = torch.arange(10).reshape(5,2)\n        >>> a\n        tensor([[0, 1],\n                [2, 3],\n                [4, 5],\n                [6, 7],\n                [8, 9]])\n        >>> torch.split(a, 2)\n        (tensor([[0, 1],\n                 [2, 3]]),\n         tensor([[4, 5],\n                 [6, 7]]),\n         tensor([[8, 9]]))\n        >>> torch.split(a, [1,4])\n        (tensor([[0, 1]]),\n         tensor([[2, 3],\n                 [4, 5],\n                 [6, 7],\n                 [8, 9]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(tensor) is not Tensor and has_torch_function((tensor,)):\n            return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,\n                                         dim=dim)\n    # Overwriting reason:\n    # This dispatches to two ATen functions depending on the type of\n    # split_size_or_sections. 
The branching code is in tensor.py, which we\n    # call here.\n    return tensor.split(split_size_or_sections, dim)\n\n# equivalent to itertools.product(indices)\ndef _indices_product(indices):\n    # type: (List[int]) -> (List[List[int]])\n    empty_list = torch.jit.annotate(List[int], [])\n    result = [empty_list]\n    for idx in indices:\n        result_temp = torch.jit.annotate(List[List[int]], [])\n        for res in result:\n            for i in range(idx):\n                result_temp.append(res + [i])\n        result = result_temp\n    return result\n\ndef _index_tensor_with_indices_list(tensor, indices):\n    # type: (Tensor, List[int]) -> Tensor\n    out = tensor\n    for index in indices:\n        out = out[index]\n    return out\n\ndef lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):\n    # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])\n    r\"\"\"Unpacks the data and pivots from a LU factorization of a tensor.\n\n    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.\n\n    Arguments:\n        LU_data (Tensor): the packed LU factorization data\n        LU_pivots (Tensor): the packed LU factorization pivots\n        unpack_data (bool): flag indicating if the data should be unpacked\n        unpack_pivots (bool): flag indicating if the pivots should be unpacked\n\n    Examples::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>>\n        >>> # can recover A from factorization\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n\n        >>> # LU factorization of a rectangular matrix:\n        >>> A = torch.randn(2, 3, 2)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>> P\n        tensor([[[1., 0., 0.],\n                 [0., 1., 0.],\n                 [0., 0., 1.]],\n\n                [[0., 0., 1.],\n      
           [0., 1., 0.],\n                 [1., 0., 0.]]])\n        >>> A_L\n        tensor([[[ 1.0000,  0.0000],\n                 [ 0.4763,  1.0000],\n                 [ 0.3683,  0.1135]],\n\n                [[ 1.0000,  0.0000],\n                 [ 0.2957,  1.0000],\n                 [-0.9668, -0.3335]]])\n        >>> A_U\n        tensor([[[ 2.1962,  1.0881],\n                 [ 0.0000, -0.8681]],\n\n                [[-1.0947,  0.3736],\n                 [ 0.0000,  0.5718]]])\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n        >>> torch.norm(A_ - A)\n        tensor(2.9802e-08)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        tens_ops = (LU_data, LU_pivots)\n        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):\n            return handle_torch_function(\n                lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,\n                unpack_pivots=unpack_pivots)\n    shape = LU_data.shape\n    # In generalized LU factorization, the following shape relations hold:\n    #   A.shape[-2:] == (m, n)\n    #   P.shape[-2:] == (m, m)\n    #   L.shape[-2:] == (m, k)\n    #   U.shape[-2:] == (k, n)\n    # where k = min(m, n)\n    m, n = shape[-2:]\n    k = min(m, n)\n    if unpack_data:\n        U = LU_data.triu()\n        if m != k:\n            U = U.narrow(-2, 0, k)\n        L = LU_data.tril()\n        if k != n:\n            L = L.narrow(-1, 0, k)\n        L.diagonal(dim1=-2, dim2=-1).fill_(1)\n    else:\n        L = U = None\n\n    if unpack_pivots:\n        LU_pivots_zero_idx = LU_pivots - 1\n        if LU_data.dim() > 2:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) \\\n                     .expand(shape[:-1] + (m,)) \\\n                     .clone(memory_format=torch.contiguous_format)\n\n            # TODO: rewrite when TorchScript supports product and map as\n            # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed\n            
indices = _indices_product(shape[:-2])\n            for idx in indices:\n                final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n                for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):\n                    final_order[k], final_order[j] = final_order[j], final_order[k]\n                # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list\n                p_idx = _index_tensor_with_indices_list(P, idx)\n                p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))\n        else:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)\n            final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n            for k, j, in enumerate(LU_pivots_zero_idx):\n                final_order[k], final_order[j] = final_order[j], final_order[k]\n            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))\n    else:\n        P = None\n\n    return P, L, U\n\n\ndef einsum(equation, *operands):\n    r\"\"\"einsum(equation, *operands) -> Tensor\n\nThis function provides a way of computing multilinear expressions (i.e. sums of products) using the\nEinstein summation convention.\n\nArgs:\n    equation (string): The equation is given in terms of lower case letters (indices) to be associated\n           with each dimension of the operands and result. The left hand side lists the operands\n           dimensions, separated by commas. 
There should be one index letter per tensor dimension.\n           The right hand side follows after `->` and gives the indices for the output.\n           If the `->` and right hand side are omitted, it implicitly defined as the alphabetically\n           sorted list of all indices appearing exactly once in the left hand side.\n           The indices not apprearing in the output are summed over after multiplying the operands\n           entries.\n           If an index appears several times for the same operand, a diagonal is taken.\n           Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,\n           the ellipsis dimensions are at the beginning of the output.\n    operands (Tensor): The operands to compute the Einstein sum of.\n\n.. note::\n\n    This function does not optimize the given expression, so a different formula for the same computation may\n    run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)\n    can optimize the formula for you.\n\nExamples::\n\n    >>> x = torch.randn(5)\n    >>> y = torch.randn(4)\n    >>> torch.einsum('i,j->ij', x, y)  # outer product\n    tensor([[-0.0570, -0.0286, -0.0231,  0.0197],\n            [ 1.2616,  0.6335,  0.5113, -0.4351],\n            [ 1.4452,  0.7257,  0.5857, -0.4984],\n            [-0.4647, -0.2333, -0.1883,  0.1603],\n            [-1.1130, -0.5588, -0.4510,  0.3838]])\n\n\n    >>> A = torch.randn(3,5,4)\n    >>> l = torch.randn(2,5)\n    >>> r = torch.randn(2,4)\n    >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear\n    tensor([[-0.3430, -5.2405,  0.4494],\n            [ 0.3311,  5.5201, -3.0356]])\n\n\n    >>> As = torch.randn(3,2,5)\n    >>> Bs = torch.randn(3,5,4)\n    >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication\n    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],\n             [-1.6706, -0.8097, -0.8025, -2.1183]],\n\n            [[ 4.2239,  
0.3107, -0.5756, -0.2354],\n             [-1.4558, -0.3460,  1.5087, -0.8530]],\n\n            [[ 2.8153,  1.8787, -4.3839, -1.2112],\n             [ 0.3728, -2.1131,  0.0921,  0.8305]]])\n\n    >>> A = torch.randn(3, 3)\n    >>> torch.einsum('ii->i', A) # diagonal\n    tensor([-0.7825,  0.8291, -0.1936])\n\n    >>> A = torch.randn(4, 3, 3)\n    >>> torch.einsum('...ii->...i', A) # batch diagonal\n    tensor([[-1.0864,  0.7292,  0.0569],\n            [-0.9725, -1.0270,  0.6493],\n            [ 0.5832, -1.1716, -1.5084],\n            [ 0.4041, -1.1690,  0.8570]])\n\n    >>> A = torch.randn(2, 3, 4, 5)\n    >>> torch.einsum('...ij->...ji', A).shape # batch permute\n    torch.Size([2, 3, 5, 4])\n\"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):\n            return handle_torch_function(einsum, operands, equation, *operands)\n\n    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        operands = operands[0]\n        # recurse incase operands contains value that has torch function\n        # in the original implementation this line is omitted\n        return einsum(equation, *operands)\n\n    return _VF.einsum(equation, operands)\n\n\ndef meshgrid(*tensors):\n    r\"\"\"Take :math:`N` tensors, each of which can be either scalar or 1-dimensional\nvector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by\nexpanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.\n\n\n    Args:\n        tensors (list of Tensor): list of scalars or 1 dimensional tensors. 
Scalars will be\n        treated as tensors of size :math:`(1,)` automatically\n\n    Returns:\n        seq (sequence of Tensors): If the input has :math:`k` tensors of size\n        :math:`(N_1,), (N_2,), \\ldots , (N_k,)`, then the output would also have :math:`k` tensors,\n        where all tensors are of size :math:`(N_1, N_2, \\ldots , N_k)`.\n\n    Example::\n\n        >>> x = torch.tensor([1, 2, 3])\n        >>> y = torch.tensor([4, 5, 6])\n        >>> grid_x, grid_y = torch.meshgrid(x, y)\n        >>> grid_x\n        tensor([[1, 1, 1],\n                [2, 2, 2],\n                [3, 3, 3]])\n        >>> grid_y\n        tensor([[4, 5, 6],\n                [4, 5, 6],\n                [4, 5, 6]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(meshgrid, tensors, *tensors)\n    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        tensors = tensors[0]\n    return _VF.meshgrid(tensors)\n\n\ndef stft(input, n_fft, hop_length=None, win_length=None, window=None,\n         center=True, pad_mode='reflect', normalized=False, onesided=True):\n    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor\n    r\"\"\"Short-time Fourier transform (STFT).\n\n    Ignoring the optional batch dimension, this method computes the following\n    expression:\n\n    .. math::\n        X[m, \\omega] = \\sum_{k = 0}^{\\text{win\\_length-1}}%\n                            \\text{window}[k]\\ \\text{input}[m \\times \\text{hop\\_length} + k]\\ %\n                            \\exp\\left(- j \\frac{2 \\pi \\cdot \\omega k}{\\text{win\\_length}}\\right),\n\n    where :math:`m` is the index of the sliding window, and :math:`\\omega` is\n    the frequency that :math:`0 \\leq \\omega < \\text{n\\_fft}`. 
When\n    :attr:`onesided` is the default value ``True``,\n\n    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time\n      sequences.\n\n    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to\n      ``floor(n_fft / 4)``.\n\n    * If :attr:`win_length` is ``None`` (default), it is treated as equal to\n      :attr:`n_fft`.\n\n    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from\n      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is\n      treated as if having :math:`1` everywhere in the window. If\n      :math:`\\text{win\\_length} < \\text{n\\_fft}`, :attr:`window` will be padded on\n      both sides to length :attr:`n_fft` before being applied.\n\n    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on\n      both sides so that the :math:`t`-th frame is centered at time\n      :math:`t \\times \\text{hop\\_length}`. Otherwise, the :math:`t`-th frame\n      begins at time  :math:`t \\times \\text{hop\\_length}`.\n\n    * :attr:`pad_mode` determines the padding method used on :attr:`input` when\n      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for\n      all available options. 
Default is ``\"reflect\"``.\n\n    * If :attr:`onesided` is ``True`` (default), only values for :math:`\\omega`\n      in :math:`\\left[0, 1, 2, \\dots, \\left\\lfloor \\frac{\\text{n\\_fft}}{2} \\right\\rfloor\\right]`\n      are returned because the real-to-complex Fourier transform satisfies the\n      conjugate symmetry, i.e., :math:`X[m, \\omega] = X[m, \\text{n\\_fft} - \\omega]^*`.\n\n    * If :attr:`normalized` is ``True`` (default is ``False``), the function\n      returns the normalized STFT results, i.e., multiplied by :math:`(\\text{frame\\_length})^{-0.5}`.\n\n    Returns the real and the imaginary parts together as one tensor of size\n    :math:`(* \\times N \\times T \\times 2)`, where :math:`*` is the optional\n    batch size of :attr:`input`, :math:`N` is the number of frequencies where\n    STFT is applied, :math:`T` is the total number of frames used, and each pair\n    in the last dimension represents a complex number as the real part and the\n    imaginary part.\n\n    .. warning::\n      This function changed signature at version 0.4.1. Calling with the\n      previous signature may cause an error or return an incorrect result.\n\n    Arguments:\n        input (Tensor): the input tensor\n        n_fft (int): size of Fourier transform\n        hop_length (int, optional): the distance between neighboring sliding window\n            frames. 
Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)\n        win_length (int, optional): the size of window frame and STFT filter.\n            Default: ``None``  (treated as equal to :attr:`n_fft`)\n        window (Tensor, optional): the optional window function.\n            Default: ``None`` (treated as window of all :math:`1` s)\n        center (bool, optional): whether to pad :attr:`input` on both sides so\n            that the :math:`t`-th frame is centered at time :math:`t \\times \\text{hop\\_length}`.\n            Default: ``True``\n        pad_mode (string, optional): controls the padding method used when\n            :attr:`center` is ``True``. Default: ``\"reflect\"``\n        normalized (bool, optional): controls whether to return the normalized STFT results\n             Default: ``False``\n        onesided (bool, optional): controls whether to return half of results to\n            avoid redundancy Default: ``True``\n\n    Returns:\n        Tensor: A tensor containing the STFT result with shape described above\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, pad_mode=pad_mode, normalized=normalized,\n                onesided=onesided)\n    # TODO: after having proper ways to map Python strings to ATen Enum, move\n    #       this and F.pad to ATen.\n    if center:\n        signal_dim = input.dim()\n        extended_shape = [1] * (3 - signal_dim) + list(input.size())\n        pad = int(n_fft // 2)\n        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)\n        input = input.view(input.shape[-signal_dim:])\n    return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)\n\n\ndef istft(input, n_fft, hop_length=None, win_length=None, window=None,\n    
      center=True, normalized=False, onesided=True, length=None):\n    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, bool, bool, Optional[int]) -> Tensor\n    r\"\"\"Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.\n    It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the\n    least squares estimation of the original signal. The algorithm will check using the NOLA condition (\n    nonzero overlap).\n\n    An important consideration is choosing the parameters :attr:`window` and :attr:`center` so that the envelope\n    created by the summation of all the windows is never zero at any point in time. Specifically,\n    :math:`\\sum_{t=-\\infty}^{\\infty} w^2[n-t\\times hop\\_length] \\neq 0`.\n\n    Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,\n    ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False\n    since the signal isn't padded).\n\n    If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.\n    Left padding can be trimmed off exactly because it can be calculated but right padding cannot be\n    calculated without additional information.\n\n    Example: Suppose the last window is:\n    ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``\n\n    The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation\n    of right padding. These additional values could be zeros or a reflection of the signal so providing\n    :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed\n    (some loss of signal).\n\n    [1] D. W. Griffin and J. S. Lim, \"Signal estimation from modified short-time Fourier transform,\"\n    IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 
1984.\n\n    Arguments:\n        input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`,\n            either 3D (``fft_size``, ``n_frame``, 2) or 4D (``channel``, ``fft_size``, ``n_frame``, 2).\n        n_fft (int): Size of Fourier transform\n        hop_length (Optional[int]): The distance between neighboring sliding window frames.\n            (Default: ``n_fft // 4``)\n        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)\n        window (Optional[torch.Tensor]): The optional window function.\n            (Default: ``torch.ones(win_length)``)\n        center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is\n            centered at time :math:`t \\times \\text{hop\\_length}`.\n            (Default: ``True``)\n        normalized (bool): Whether the STFT was normalized. (Default: ``False``)\n        onesided (bool): Whether the STFT is onesided. (Default: ``True``)\n        length (Optional[int]): The amount to trim the signal by (i.e. the\n            original signal length). 
(Default: whole signal)\n\n    Returns:\n        Tensor: Least squares estimation of the original signal of size (..., signal_length)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, normalized=normalized, onesided=onesided,\n                length=length)\n\n    return _VF.istft(\n        input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)\n\n\ndel torch.unique_dim\n\n\ndef _unique_impl(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Returns the unique elements of the input tensor.\n\n    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that\n        this function also eliminates non-consecutive duplicate values.\n\n    .. note:: Currently in the CUDA implementation and the CPU implementation when dim is specified,\n        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.\n        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use\n        :func:`torch.unique_consecutive` which avoids the sorting.\n\n    Arguments:\n        input (Tensor): the input tensor\n        sorted (bool): Whether to sort the unique elements in ascending order\n            before returning as output.\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. 
If ``None``, the unique of the\n            flattened input is returned. default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))\n        >>> output\n        tensor([ 1,  2,  3])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([ 0,  2,  1,  2])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([[ 0,  2],\n                [ 1,  2]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique, (input,), input, sorted=sorted, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n\n   
 if dim is not None:\n        output, inverse_indices, counts = _VF.unique_dim(\n            input,\n            dim,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    else:\n        output, inverse_indices, counts = torch._unique2(\n            input,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    return output, inverse_indices, counts\n\n\ndef _unique_consecutive_impl(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Eliminates all but the first element from every consecutive group of equivalent elements.\n\n    .. note:: This function is different from :func:`torch.unique` in the sense that this function\n        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`\n        in C++.\n\n    Arguments:\n        input (Tensor): the input tensor\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])\n        >>> output = torch.unique_consecutive(x)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n\n        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> inverse_indices\n        tensor([0, 0, 1, 1, 2, 3, 3, 4])\n\n        >>> output, counts = torch.unique_consecutive(x, return_counts=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> counts\n        tensor([2, 2, 1, 2, 1])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique_consecutive, (input,), input, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n    output, inverse_indices, counts = _VF.unique_consecutive(\n        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)\n    return output, inverse_indices, counts\n\n\ndef 
_return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, counts\n\ndef _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output\n\ndef _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_return_counts,\n    if_false=_return_output,\n    module_name=__name__,\n    func_name='unique')\n\n_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_unique_impl,\n    if_false=_return_inverse,\n    module_name=__name__,\n    func_name='unique')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# 
resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_return_inverse_true,\n    if_false=_return_inverse_false,\n    module_name=__name__,\n    func_name='unique')\nunique.__doc__ = _unique_impl.__doc__\n\n\ndef _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, counts\n\ndef _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output\n\ndef _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n_consecutive_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_consecutive_return_counts,\n    
if_false=_consecutive_return_output,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n_consecutive_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_unique_consecutive_impl,\n    if_false=_consecutive_return_inverse,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique_consecutive = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_consecutive_return_inverse_true,\n    if_false=_consecutive_return_inverse_false,\n    module_name=__name__,\n    func_name='unique_consecutive')\nunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__\n\n\ndef tensordot(a, b, dims=2):\n    r\"\"\"Returns a contraction of a and b over multiple dimensions.\n\n    :attr:`tensordot` implements a generalized matrix product.\n\n    Args:\n      a (Tensor): Left tensor to contract\n      b (Tensor): Right tensor to contract\n      dims (int or tuple of two lists of integers): number of dimensions to\n         contract or explicit lists of dimensions for :attr:`a` and\n         :attr:`b` respectively\n\n    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and\n    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,\n    respectively, :func:`~torch.tensordot` computes\n\n    .. math::\n        r_{i_0,...,i_{m-d}, i_d,...,i_n}\n          = \\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}.\n\n    When called with :attr:`dims` of the list form, the given dimensions will be contracted\n    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. 
The sizes\n    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted\n    dimensions.\n\n    Examples::\n\n        >>> a = torch.arange(60.).reshape(3, 4, 5)\n        >>> b = torch.arange(24.).reshape(4, 3, 2)\n        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))\n        tensor([[4400., 4730.],\n                [4532., 4874.],\n                [4664., 5018.],\n                [4796., 5162.],\n                [4928., 5306.]])\n\n        >>> a = torch.randn(3, 4, 5, device='cuda')\n        >>> b = torch.randn(4, 5, 6, device='cuda')\n        >>> c = torch.tensordot(a, b, dims=2).cpu()\n        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],\n                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],\n                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):\n            return handle_torch_function(tensordot, (a, b), a, b, dims=dims)\n    if isinstance(dims, (list, tuple)) or \\\n       (isinstance(dims, torch.Tensor) and dims.numel() > 1):\n        dims_a, dims_b = dims\n    else:\n        if isinstance(dims, torch.Tensor):\n            dims = dims.item()\n        if dims < 0:\n            raise RuntimeError(\"tensordot expects dims >= 0, but got dims={}\".format(dims))\n        dims_a = list(range(-dims, 0))\n        dims_b = list(range(dims))\n    return _VF.tensordot(a, b, dims_a, dims_b)\n\ndef cartesian_prod(*tensors):\n    \"\"\"Do cartesian product of the given sequence of tensors. 
The behavior is similar to\n    python's `itertools.product`.\n\n    Arguments:\n        *tensors: any number of 1 dimensional tensors.\n\n    Returns:\n        Tensor: A tensor equivalent to converting all the input tensors into lists,\n            do `itertools.product` on these lists, and finally convert the resulting list\n            into tensor.\n\n    Example::\n\n        >>> a = [1, 2, 3]\n        >>> b = [4, 5]\n        >>> list(itertools.product(a, b))\n        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]\n        >>> tensor_a = torch.tensor(a)\n        >>> tensor_b = torch.tensor(b)\n        >>> torch.cartesian_prod(tensor_a, tensor_b)\n        tensor([[1, 4],\n                [1, 5],\n                [2, 4],\n                [2, 5],\n                [3, 4],\n                [3, 5]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(cartesian_prod, tensors, *tensors)\n    return _VF.cartesian_prod(tensors)\n\ndef block_diag(*tensors):\n    \"\"\"Create a block diagonal matrix from provided tensors.\n\n    Arguments:\n        *tensors: One or more tensors with 0, 1, or 2 dimensions.\n\n    Returns:\n        Tensor: A 2 dimensional tensor with all the input tensors arranged in\n            order such that their upper left and lower right corners are\n            diagonally adjacent. 
All other elements are set to 0.\n\n    Example::\n\n        >>> import torch\n        >>> A = torch.tensor([[0, 1], [1, 0]])\n        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])\n        >>> C = torch.tensor(7)\n        >>> D = torch.tensor([1, 2, 3])\n        >>> E = torch.tensor([[4], [5], [6]])\n        >>> torch.block_diag(A, B, C, D, E)\n        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],\n                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])\n    \"\"\"\n    if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n        return handle_torch_function(block_diag, tensors, *tensors)\n    return torch._C._VariableFunctions.block_diag(tensors)\n\n\ndef cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):\n    # type: (Tensor, Tensor, float, str) -> (Tensor)\n    r\"\"\"Computes batched the p-norm distance between each pair of the two collections of row vectors.\n\n    Args:\n        x1 (Tensor): input tensor of shape :math:`B \\times P \\times M`.\n        x2 (Tensor): input tensor of shape :math:`B \\times R \\times M`.\n        p: p value for the p-norm distance to calculate between each vector pair\n            :math:`\\in [0, \\infty]`.\n        compute_mode:\n            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate\n            euclidean distance (p = 2) if P > 25 or R > 25\n            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate\n            euclidean distance (p = 2)\n            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate\n            euclidean 
distance (p = 2)\n            Default: use_mm_for_euclid_dist_if_necessary.\n\n    If x1 has shape :math:`B \\times P \\times M` and x2 has shape :math:`B \\times R \\times M` then the\n    output will have shape :math:`B \\times P \\times R`.\n\n    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`\n    if :math:`p \\in (0, \\infty)`. When :math:`p = 0` it is equivalent to\n    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \\infty`, the closest\n    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.\n\n    Example:\n\n        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])\n        >>> a\n        tensor([[ 0.9041,  0.0196],\n                [-0.3108, -2.4423],\n                [-0.4821,  1.0590]])\n        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])\n        >>> b\n        tensor([[-2.1763, -0.4713],\n                [-0.6986,  1.3702]])\n        >>> torch.cdist(a, b, p=2)\n        tensor([[3.1193, 2.0959],\n                [2.7138, 3.8322],\n                [2.2830, 0.3791]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):\n            return handle_torch_function(\n                cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)\n    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':\n        return _VF.cdist(x1, x2, p, None)\n    elif compute_mode == 'use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 1)\n    elif compute_mode == 'donot_use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 2)\n    else:\n        raise ValueError(\"{} is not a valid value for compute_mode\".format(compute_mode))\n\n# TODO: type dim as BroadcastingList when https://github.com/pytorch/pytorch/issues/33782 is fixed\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, 
dtype=None):  # noqa: 749\n    # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    r\"\"\"Returns the matrix norm or vector norm of a given tensor.\n\n    Args:\n        input (Tensor): the input tensor\n        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``\n            The following norms can be calculated:\n\n            =====  ============================  ==========================\n            ord    matrix norm                   vector norm\n            =====  ============================  ==========================\n            None   Frobenius norm                2-norm\n            'fro'  Frobenius norm                --\n            'nuc'  nuclear norm                  --\n            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)\n            =====  ============================  ==========================\n\n        dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,\n            vector norm will be calculated, if it is 2-tuple of ints, matrix norm\n            will be calculated. 
If the value is None, matrix norm will be calculated\n            when the input tensor only has two dimensions, vector norm will be\n            calculated when the input tensor only has one dimension. If the input\n            tensor has more than two dimensions, the vector norm will be applied to\n            last dimension.\n        keepdim (bool, optional): whether the output tensors have :attr:`dim`\n            retained or not. Ignored if :attr:`dim` = ``None`` and\n            :attr:`out` = ``None``. Default: ``False``\n        out (Tensor, optional): the output tensor. Ignored if\n            :attr:`dim` = ``None`` and :attr:`out` = ``None``.\n        dtype (:class:`torch.dtype`, optional): the desired data type of\n            returned tensor. If specified, the input tensor is cast to\n            :attr:`dtype` while performing the operation. Default: None.\n\n\n    Example::\n\n        >>> import torch\n        >>> a = torch.arange(9, dtype= torch.float) - 4\n        >>> b = a.reshape((3, 3))\n        >>> torch.norm(a)\n        tensor(7.7460)\n        >>> torch.norm(b)\n        tensor(7.7460)\n        >>> torch.norm(a, float('inf'))\n        tensor(4.)\n        >>> torch.norm(b, float('inf'))\n        tensor(4.)\n        >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)\n        >>> torch.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000])\n        >>> torch.norm(c, dim=1)\n        tensor([3.7417, 4.2426])\n        >>> torch.norm(c, p=1, dim=1)\n        tensor([6., 6.])\n        >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)\n        >>> torch.norm(d, dim=(1,2))\n        tensor([ 3.7417, 11.2250])\n        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])\n        (tensor(3.7417), tensor(11.2250))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                norm, (input,), input, p=p, dim=dim, 
keepdim=keepdim, out=out, dtype=dtype)\n\n    ndim = input.dim()\n\n\n    # catch default case\n    if dim is None and out is None and dtype is None and p is not None:\n        if isinstance(p, str):\n            if p == \"fro\":\n                return _VF.frobenius_norm(input)\n        if not isinstance(p, str):\n            return _VF.norm(input, p)\n\n    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed\n    # remove the overloads where dim is an int and replace with BraodcastingList1\n    # and remove next four lines, replace _dim with dim\n    if dim is not None:\n        if isinstance(dim, int):\n            _dim = [dim]\n        else:\n            _dim = dim\n    else:\n        _dim = None\n\n    if isinstance(p, str):\n        if p == \"fro\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in frobenius norm\")\n\n            if _dim is None:\n                _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            if out is None:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)\n            else:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)\n        elif p == \"nuc\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in nuclear norm\")\n            if _dim is None:\n                if out is None:\n                    return _VF.nuclear_norm(input, keepdim=keepdim)\n                else:\n                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)\n            else:\n                if out is None:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)\n                else:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)\n        raise RuntimeError(\"only valid string values are 'fro' and 'nuc', found {}\".format(p))\n    else:\n        if _dim is None:\n         
   _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n\n        if out is None:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim)\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)\n        else:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)\n\ndef chain_matmul(*matrices):\n    r\"\"\"Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed\n    using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms\n    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`\n    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.\n    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.\n\n\n    Args:\n        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.\n\n\n    Returns:\n        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \\times p_{i + 1}`, then the product\n        would be of dimensions :math:`p_{1} \\times p_{N + 1}`.\n\n    Example::\n\n        >>> a = torch.randn(3, 4)\n        >>> b = torch.randn(4, 5)\n        >>> c = torch.randn(5, 6)\n        >>> d = torch.randn(6, 7)\n        >>> torch.chain_matmul(a, b, c, d)\n        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],\n                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],\n                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])\n\n    .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):\n            return handle_torch_function(chain_matmul, matrices, *matrices)\n    return _VF.chain_matmul(matrices)\n\n\ndef _lu_impl(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Computes the LU factorization of a matrix or batches of matrices\n    :attr:`A`. Returns a tuple containing the LU factorization and\n    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to\n    ``True``.\n\n    .. note::\n        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,\n        then the returned pivots is a tensor filled with zeros of the appropriate size.\n\n    .. note::\n        LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting\n        to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is\n        available for CUDA.\n\n    .. note::\n        This function does not check if the factorization was successful or not if\n        :attr:`get_infos` is ``True`` since the status of the factorization is present in the\n        third element of the return tuple.\n\n    .. note::\n        In the case of batches of square matrices with size less or\n        equal to 32 on a CUDA device, the LU factorization is repeated\n        for singular matrices due to the bug in the MAGMA library (see\n        magma issue 13).\n\n    .. note::\n       ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.\n\n    Arguments:\n        A (Tensor): the tensor to factor of size :math:`(*, m, n)`\n        pivot (bool, optional): controls whether pivoting is done. 
Default: ``True``\n        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.\n                                    Default: ``False``\n        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,\n                               then the elements in the tuple are Tensor, IntTensor,\n                               and IntTensor. If :attr:`get_infos` is ``False``, then the\n                               elements in the tuple are Tensor, IntTensor. Default: ``None``\n\n    Returns:\n        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing\n\n            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`\n\n            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`\n\n            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of\n              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or\n              each minibatch has succeeded or failed\n\n    Example::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = torch.lu(A)\n        >>> A_LU\n        tensor([[[ 1.3506,  2.5558, -0.0816],\n                 [ 0.1684,  1.1551,  0.1940],\n                 [ 0.1193,  0.6189, -0.5497]],\n\n                [[ 0.4526,  1.2526, -0.3285],\n                 [-0.7988,  0.7175, -0.9701],\n                 [ 0.2634, -0.9255, -0.3459]]])\n        >>> pivots\n        tensor([[ 3,  3,  3],\n                [ 3,  3,  3]], dtype=torch.int32)\n        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)\n        >>> if info.nonzero().size(0) == 0:\n        ...   
print('LU factorization succeeded for all samples!')\n        LU factorization succeeded for all samples!\n    \"\"\"\n    # If get_infos is True, then we don't need to check for errors and vice versa\n    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))\n\ndef _check_list_size(out_len, get_infos, out):\n    # type: (int, bool, List[Tensor]) -> None\n    get_infos_int = 1 if get_infos else 0\n    if out_len - get_infos_int != 2:\n        raise TypeError(\"expected tuple of {} elements but got {}\"\n                        .format(2 + int(get_infos), len(out_len)))\n    if not isinstance(out, (tuple, list)):\n        raise TypeError(\"argument 'out' must be tuple of Tensors, not {}\"\n                        .format(type(out).__name__))\n\ndef _lu_with_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        _check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result  # A_LU, pivots, infos\n\ndef _lu_no_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]\n    # need to check for torch_function here so that we exit if\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        
_check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result[0], result[1]  # A_LU, pivots\n\n# The return type of lu depends on `get_infos`, so in order to resolve the output type\n# of lu in TorchScript we need to statically know the value of `get_infos`\nlu = boolean_dispatch(\n    arg_name='get_infos',\n    arg_index=2,\n    default=False,\n    if_true=_lu_with_infos,\n    if_false=_lu_no_infos,\n    module_name=__name__,\n    func_name='lu')\nlu.__doc__ = _lu_impl.__doc__\n\ndef align_tensors(*tensors):\n    raise RuntimeError('`align_tensors` not yet implemented.')\n"
  },
  {
    "path": "patches/pytorch/1.6.0/functional.py",
    "content": "from typing import Tuple, Optional\n\nimport librosa  # STFT patch for aarch64\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\nfrom ._lowrank import svd_lowrank, pca_lowrank\nfrom ._overrides import has_torch_function, handle_torch_function\nfrom ._jit_internal import boolean_dispatch, List\nfrom ._jit_internal import _overload as overload\n\nTensor = torch.Tensor\nfrom torch import _VF\n\n__all__ = [\n    'align_tensors',\n    'broadcast_tensors',\n    'cartesian_prod',\n    'block_diag',\n    'cdist',\n    'chain_matmul',\n    'einsum',\n    'istft',\n    'lu',\n    'lu_unpack',\n    'norm',\n    'meshgrid',\n    'pca_lowrank',\n    'split',\n    'stft',\n    'svd_lowrank',\n    'tensordot',\n    'unique',\n    'unique_consecutive',\n]\n\n\ndef broadcast_tensors(*tensors):\n    r\"\"\"broadcast_tensors(*tensors) -> List of Tensors\n\n    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.\n\n    Args:\n        *tensors: any number of tensors of the same type\n\n    .. warning::\n\n        More than one element of a broadcasted tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. If you need to write\n        to the tensors, please clone them first.\n\n    Example::\n\n        >>> x = torch.arange(3).view(1, 3)\n        >>> y = torch.arange(2).view(2, 1)\n        >>> a, b = torch.broadcast_tensors(x, y)\n        >>> a.size()\n        torch.Size([2, 3])\n        >>> a\n        tensor([[0, 1, 2],\n                [0, 1, 2]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(broadcast_tensors, tensors, *tensors)\n    return _VF.broadcast_tensors(tensors)\n\n\ndef split(tensor, split_size_or_sections, dim=0):\n    r\"\"\"Splits the tensor into chunks. 
Each chunk is a view of the original tensor.\n\n    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will\n    be split into equally sized chunks (if possible). Last chunk will be smaller if\n    the tensor size along the given dimension :attr:`dim` is not divisible by\n    :attr:`split_size`.\n\n    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split\n    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according\n    to :attr:`split_size_or_sections`.\n\n    Arguments:\n        tensor (Tensor): tensor to split.\n        split_size_or_sections (int) or (list(int)): size of a single chunk or\n            list of sizes for each chunk\n        dim (int): dimension along which to split the tensor.\n\n    Example::\n        >>> a = torch.arange(10).reshape(5,2)\n        >>> a\n        tensor([[0, 1],\n                [2, 3],\n                [4, 5],\n                [6, 7],\n                [8, 9]])\n        >>> torch.split(a, 2)\n        (tensor([[0, 1],\n                 [2, 3]]),\n         tensor([[4, 5],\n                 [6, 7]]),\n         tensor([[8, 9]]))\n        >>> torch.split(a, [1,4])\n        (tensor([[0, 1]]),\n         tensor([[2, 3],\n                 [4, 5],\n                 [6, 7],\n                 [8, 9]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(tensor) is not Tensor and has_torch_function((tensor,)):\n            return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,\n                                         dim=dim)\n    # Overwriting reason:\n    # This dispatches to two ATen functions depending on the type of\n    # split_size_or_sections. 
The branching code is in tensor.py, which we\n    # call here.\n    return tensor.split(split_size_or_sections, dim)\n\n# equivalent to itertools.product(indices)\ndef _indices_product(indices):\n    # type: (List[int]) -> (List[List[int]])\n    empty_list = torch.jit.annotate(List[int], [])\n    result = [empty_list]\n    for idx in indices:\n        result_temp = torch.jit.annotate(List[List[int]], [])\n        for res in result:\n            for i in range(idx):\n                result_temp.append(res + [i])\n        result = result_temp\n    return result\n\ndef _index_tensor_with_indices_list(tensor, indices):\n    # type: (Tensor, List[int]) -> Tensor\n    out = tensor\n    for index in indices:\n        out = out[index]\n    return out\n\ndef lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):\n    # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])\n    r\"\"\"Unpacks the data and pivots from a LU factorization of a tensor.\n\n    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.\n\n    Arguments:\n        LU_data (Tensor): the packed LU factorization data\n        LU_pivots (Tensor): the packed LU factorization pivots\n        unpack_data (bool): flag indicating if the data should be unpacked\n        unpack_pivots (bool): flag indicating if the pivots should be unpacked\n\n    Examples::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>>\n        >>> # can recover A from factorization\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n\n        >>> # LU factorization of a rectangular matrix:\n        >>> A = torch.randn(2, 3, 2)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>> P\n        tensor([[[1., 0., 0.],\n                 [0., 1., 0.],\n                 [0., 0., 1.]],\n\n                [[0., 0., 1.],\n      
           [0., 1., 0.],\n                 [1., 0., 0.]]])\n        >>> A_L\n        tensor([[[ 1.0000,  0.0000],\n                 [ 0.4763,  1.0000],\n                 [ 0.3683,  0.1135]],\n\n                [[ 1.0000,  0.0000],\n                 [ 0.2957,  1.0000],\n                 [-0.9668, -0.3335]]])\n        >>> A_U\n        tensor([[[ 2.1962,  1.0881],\n                 [ 0.0000, -0.8681]],\n\n                [[-1.0947,  0.3736],\n                 [ 0.0000,  0.5718]]])\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n        >>> torch.norm(A_ - A)\n        tensor(2.9802e-08)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        tens_ops = (LU_data, LU_pivots)\n        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):\n            return handle_torch_function(\n                lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,\n                unpack_pivots=unpack_pivots)\n    shape = LU_data.shape\n    # In generalized LU factorization, the following shape relations hold:\n    #   A.shape[-2:] == (m, n)\n    #   P.shape[-2:] == (m, m)\n    #   L.shape[-2:] == (m, k)\n    #   U.shape[-2:] == (k, n)\n    # where k = min(m, n)\n    m, n = shape[-2:]\n    k = min(m, n)\n    if unpack_data:\n        U = LU_data.triu()\n        if m != k:\n            U = U.narrow(-2, 0, k)\n        L = LU_data.tril()\n        if k != n:\n            L = L.narrow(-1, 0, k)\n        L.diagonal(dim1=-2, dim2=-1).fill_(1)\n    else:\n        L = U = None\n\n    if unpack_pivots:\n        LU_pivots_zero_idx = LU_pivots - 1\n        if LU_data.dim() > 2:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) \\\n                     .expand(shape[:-1] + (m,)) \\\n                     .clone(memory_format=torch.contiguous_format)\n\n            # TODO: rewrite when TorchScript supports product and map as\n            # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed\n            
indices = _indices_product(shape[:-2])\n            for idx in indices:\n                final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n                for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):\n                    final_order[k], final_order[j] = final_order[j], final_order[k]\n                # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list\n                p_idx = _index_tensor_with_indices_list(P, idx)\n                p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))\n        else:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)\n            final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n            for k, j, in enumerate(LU_pivots_zero_idx):\n                final_order[k], final_order[j] = final_order[j], final_order[k]\n            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))\n    else:\n        P = None\n\n    return P, L, U\n\n\ndef einsum(equation, *operands):\n    r\"\"\"einsum(equation, *operands) -> Tensor\n\nThis function provides a way of computing multilinear expressions (i.e. sums of products) using the\nEinstein summation convention.\n\nArgs:\n    equation (string): The equation is given in terms of lower case letters (indices) to be associated\n           with each dimension of the operands and result. The left hand side lists the operands\n           dimensions, separated by commas. 
There should be one index letter per tensor dimension.\n           The right hand side follows after `->` and gives the indices for the output.\n           If the `->` and right hand side are omitted, it implicitly defined as the alphabetically\n           sorted list of all indices appearing exactly once in the left hand side.\n           The indices not apprearing in the output are summed over after multiplying the operands\n           entries.\n           If an index appears several times for the same operand, a diagonal is taken.\n           Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,\n           the ellipsis dimensions are at the beginning of the output.\n    operands (Tensor): The operands to compute the Einstein sum of.\n\n.. note::\n\n    This function does not optimize the given expression, so a different formula for the same computation may\n    run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)\n    can optimize the formula for you.\n\nExamples::\n\n    >>> x = torch.randn(5)\n    >>> y = torch.randn(4)\n    >>> torch.einsum('i,j->ij', x, y)  # outer product\n    tensor([[-0.0570, -0.0286, -0.0231,  0.0197],\n            [ 1.2616,  0.6335,  0.5113, -0.4351],\n            [ 1.4452,  0.7257,  0.5857, -0.4984],\n            [-0.4647, -0.2333, -0.1883,  0.1603],\n            [-1.1130, -0.5588, -0.4510,  0.3838]])\n\n\n    >>> A = torch.randn(3,5,4)\n    >>> l = torch.randn(2,5)\n    >>> r = torch.randn(2,4)\n    >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear\n    tensor([[-0.3430, -5.2405,  0.4494],\n            [ 0.3311,  5.5201, -3.0356]])\n\n\n    >>> As = torch.randn(3,2,5)\n    >>> Bs = torch.randn(3,5,4)\n    >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication\n    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],\n             [-1.6706, -0.8097, -0.8025, -2.1183]],\n\n            [[ 4.2239,  
0.3107, -0.5756, -0.2354],\n             [-1.4558, -0.3460,  1.5087, -0.8530]],\n\n            [[ 2.8153,  1.8787, -4.3839, -1.2112],\n             [ 0.3728, -2.1131,  0.0921,  0.8305]]])\n\n    >>> A = torch.randn(3, 3)\n    >>> torch.einsum('ii->i', A) # diagonal\n    tensor([-0.7825,  0.8291, -0.1936])\n\n    >>> A = torch.randn(4, 3, 3)\n    >>> torch.einsum('...ii->...i', A) # batch diagonal\n    tensor([[-1.0864,  0.7292,  0.0569],\n            [-0.9725, -1.0270,  0.6493],\n            [ 0.5832, -1.1716, -1.5084],\n            [ 0.4041, -1.1690,  0.8570]])\n\n    >>> A = torch.randn(2, 3, 4, 5)\n    >>> torch.einsum('...ij->...ji', A).shape # batch permute\n    torch.Size([2, 3, 5, 4])\n\"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):\n            return handle_torch_function(einsum, operands, equation, *operands)\n\n    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        operands = operands[0]\n        # recurse incase operands contains value that has torch function\n        # in the original implementation this line is omitted\n        return einsum(equation, *operands)\n\n    return _VF.einsum(equation, operands)\n\n\ndef meshgrid(*tensors):\n    r\"\"\"Take :math:`N` tensors, each of which can be either scalar or 1-dimensional\nvector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by\nexpanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.\n\n\n    Args:\n        tensors (list of Tensor): list of scalars or 1 dimensional tensors. 
Scalars will be\n        treated as tensors of size :math:`(1,)` automatically\n\n    Returns:\n        seq (sequence of Tensors): If the input has :math:`k` tensors of size\n        :math:`(N_1,), (N_2,), \\ldots , (N_k,)`, then the output would also have :math:`k` tensors,\n        where all tensors are of size :math:`(N_1, N_2, \\ldots , N_k)`.\n\n    Example::\n\n        >>> x = torch.tensor([1, 2, 3])\n        >>> y = torch.tensor([4, 5, 6])\n        >>> grid_x, grid_y = torch.meshgrid(x, y)\n        >>> grid_x\n        tensor([[1, 1, 1],\n                [2, 2, 2],\n                [3, 3, 3]])\n        >>> grid_y\n        tensor([[4, 5, 6],\n                [4, 5, 6],\n                [4, 5, 6]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(meshgrid, tensors, *tensors)\n    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        tensors = tensors[0]\n    return _VF.meshgrid(tensors)\n\n\ndef stft(input, n_fft, hop_length=None, win_length=None, window=None,\n         center=True, pad_mode='reflect', normalized=False, onesided=True):\n    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor\n    r\"\"\"Short-time Fourier transform (STFT).\n\n    Ignoring the optional batch dimension, this method computes the following\n    expression:\n\n    .. math::\n        X[m, \\omega] = \\sum_{k = 0}^{\\text{win\\_length-1}}%\n                            \\text{window}[k]\\ \\text{input}[m \\times \\text{hop\\_length} + k]\\ %\n                            \\exp\\left(- j \\frac{2 \\pi \\cdot \\omega k}{\\text{win\\_length}}\\right),\n\n    where :math:`m` is the index of the sliding window, and :math:`\\omega` is\n    the frequency that :math:`0 \\leq \\omega < \\text{n\\_fft}`. 
When\n    :attr:`onesided` is the default value ``True``,\n\n    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time\n      sequences.\n\n    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to\n      ``floor(n_fft / 4)``.\n\n    * If :attr:`win_length` is ``None`` (default), it is treated as equal to\n      :attr:`n_fft`.\n\n    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from\n      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is\n      treated as if having :math:`1` everywhere in the window. If\n      :math:`\\text{win\\_length} < \\text{n\\_fft}`, :attr:`window` will be padded on\n      both sides to length :attr:`n_fft` before being applied.\n\n    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on\n      both sides so that the :math:`t`-th frame is centered at time\n      :math:`t \\times \\text{hop\\_length}`. Otherwise, the :math:`t`-th frame\n      begins at time  :math:`t \\times \\text{hop\\_length}`.\n\n    * :attr:`pad_mode` determines the padding method used on :attr:`input` when\n      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for\n      all available options. 
Default is ``\"reflect\"``.\n\n    * If :attr:`onesided` is ``True`` (default), only values for :math:`\\omega`\n      in :math:`\\left[0, 1, 2, \\dots, \\left\\lfloor \\frac{\\text{n\\_fft}}{2} \\right\\rfloor + 1\\right]`\n      are returned because the real-to-complex Fourier transform satisfies the\n      conjugate symmetry, i.e., :math:`X[m, \\omega] = X[m, \\text{n\\_fft} - \\omega]^*`.\n\n    * If :attr:`normalized` is ``True`` (default is ``False``), the function\n      returns the normalized STFT results, i.e., multiplied by :math:`(\\text{frame\\_length})^{-0.5}`.\n\n    Returns the real and the imaginary parts together as one tensor of size\n    :math:`(* \\times N \\times T \\times 2)`, where :math:`*` is the optional\n    batch size of :attr:`input`, :math:`N` is the number of frequencies where\n    STFT is applied, :math:`T` is the total number of frames used, and each pair\n    in the last dimension represents a complex number as the real part and the\n    imaginary part.\n\n    .. warning::\n      This function changed signature at version 0.4.1. Calling with the\n      previous signature may cause error or return incorrect result.\n\n    Arguments:\n        input (Tensor): the input tensor\n        n_fft (int): size of Fourier transform\n        hop_length (int, optional): the distance between neighboring sliding window\n            frames. 
Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)\n        win_length (int, optional): the size of window frame and STFT filter.\n            Default: ``None``  (treated as equal to :attr:`n_fft`)\n        window (Tensor, optional): the optional window function.\n            Default: ``None`` (treated as window of all :math:`1` s)\n        center (bool, optional): whether to pad :attr:`input` on both sides so\n            that the :math:`t`-th frame is centered at time :math:`t \\times \\text{hop\\_length}`.\n            Default: ``True``\n        pad_mode (string, optional): controls the padding method used when\n            :attr:`center` is ``True``. Default: ``\"reflect\"``\n        normalized (bool, optional): controls whether to return the normalized STFT results\n             Default: ``False``\n        onesided (bool, optional): controls whether to return half of results to\n            avoid redundancy Default: ``True``\n\n    Returns:\n        Tensor: A tensor containing the STFT result with shape described above\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, pad_mode=pad_mode, normalized=normalized,\n                onesided=onesided)\n    # TODO: after having proper ways to map Python strings to ATen Enum, move\n    #       this and F.pad to ATen.\n    if center:\n        signal_dim = input.dim()\n        extended_shape = [1] * (3 - signal_dim) + list(input.size())\n        pad = int(n_fft // 2)\n        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)\n        input = input.view(input.shape[-signal_dim:])\n        \n    # STFT patch for aarch64\n    # https://stackoverflow.com/a/66872148\n    librosa_stft = librosa.stft(input.cpu().detach().numpy().reshape(-1), n_fft, 
hop_length, win_length, window=\"hann\", center=center, pad_mode=pad_mode)\n    librosa_stft = np.array([[a.real, a.imag] for a in librosa_stft])\n    librosa_stft = np.transpose(librosa_stft, axes=[0, 2, 1])\n    librosa_stft = np.expand_dims(librosa_stft, 0)\n    librosa_stft = torch.from_numpy(librosa_stft)\n    return librosa_stft\n    #return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n    #                normalized, onesided, return_complex)\n\n\ndef istft(input, n_fft, hop_length=None, win_length=None, window=None,\n          center=True, normalized=False, onesided=True, length=None):\n    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, bool, bool, Optional[int]) -> Tensor\n    r\"\"\"Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.\n    It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the\n    least squares estimation of the original signal. The algorithm will check using the NOLA condition (\n    nonzero overlap).\n\n    Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop\n    created by the summation of all the windows is never zero at certain point in time. Specifically,\n    :math:`\\sum_{t=-\\infty}^{\\infty} w^2[n-t\\times hop\\_length] \\cancel{=} 0`.\n\n    Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,\n    ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False\n    since the signal isn't padded).\n\n    If :attr:`center` is ``True``, then there will be padding e.g. 
``'constant'``, ``'reflect'``, etc.\n    Left padding can be trimmed off exactly because they can be calculated but right padding cannot be\n    calculated without additional information.\n\n    Example: Suppose the last window is:\n    ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``\n\n    The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation\n    of right padding. These additional values could be zeros or a reflection of the signal so providing\n    :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed\n    (some loss of signal).\n\n    [1] D. W. Griffin and J. S. Lim, \"Signal estimation from modified short-time Fourier transform,\"\n    IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.\n\n    Arguments:\n        input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`,\n            either 3D (``fft_size``, ``n_frame``, 2) or 4D (``channel``, ``fft_size``, ``n_frame``, 2).\n        n_fft (int): Size of Fourier transform\n        hop_length (Optional[int]): The distance between neighboring sliding window frames.\n            (Default: ``n_fft // 4``)\n        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)\n        window (Optional[torch.Tensor]): The optional window function.\n            (Default: ``torch.ones(win_length)``)\n        center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is\n            centered at time :math:`t \\times \\text{hop\\_length}`.\n            (Default: ``True``)\n        normalized (bool): Whether the STFT was normalized. (Default: ``False``)\n        onesided (bool): Whether the STFT is onesided. (Default: ``True``)\n        length (Optional[int]): The amount to trim the signal by (i.e. the\n            original signal length). 
(Default: whole signal)\n\n    Returns:\n        Tensor: Least squares estimation of the original signal of size (..., signal_length)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, normalized=normalized, onesided=onesided,\n                length=length)\n\n    return _VF.istft(\n        input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)\n\n\ndel torch.unique_dim\n\n\ndef _unique_impl(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Returns the unique elements of the input tensor.\n\n    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that\n        this function also eliminates non-consecutive duplicate values.\n\n    .. note:: Currently in the CUDA implementation and the CPU implementation when dim is specified,\n        `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument.\n        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use\n        :func:`torch.unique_consecutive` which avoids the sorting.\n\n    Arguments:\n        input (Tensor): the input tensor\n        sorted (bool): Whether to sort the unique elements in ascending order\n            before returning as output.\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. 
If ``None``, the unique of the\n            flattened input is returned. default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))\n        >>> output\n        tensor([ 2,  3,  1])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([ 0,  2,  1,  2])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([[ 0,  2],\n                [ 1,  2]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique, (input,), input, sorted=sorted, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n\n   
 if dim is not None:\n        output, inverse_indices, counts = _VF.unique_dim(\n            input,\n            dim,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    else:\n        output, inverse_indices, counts = torch._unique2(\n            input,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    return output, inverse_indices, counts\n\n\ndef _unique_consecutive_impl(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Eliminates all but the first element from every consecutive group of equivalent elements.\n\n    .. note:: This function is different from :func:`torch.unique` in the sense that this function\n        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`\n        in C++.\n\n    Arguments:\n        input (Tensor): the input tensor\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])\n        >>> output = torch.unique_consecutive(x)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n\n        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> inverse_indices\n        tensor([0, 0, 1, 1, 2, 3, 3, 4])\n\n        >>> output, counts = torch.unique_consecutive(x, return_counts=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> counts\n        tensor([2, 2, 1, 2, 1])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique_consecutive, (input,), input, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n    output, inverse_indices, counts = _VF.unique_consecutive(\n        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)\n    return output, inverse_indices, counts\n\n\ndef 
_return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, counts\n\ndef _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output\n\ndef _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_return_counts,\n    if_false=_return_output,\n    module_name=__name__,\n    func_name='unique')\n\n_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_unique_impl,\n    if_false=_return_inverse,\n    module_name=__name__,\n    func_name='unique')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# 
resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_return_inverse_true,\n    if_false=_return_inverse_false,\n    module_name=__name__,\n    func_name='unique')\nunique.__doc__ = _unique_impl.__doc__\n\n\ndef _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, counts\n\ndef _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output\n\ndef _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n_consecutive_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_consecutive_return_counts,\n    
if_false=_consecutive_return_output,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n_consecutive_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_unique_consecutive_impl,\n    if_false=_consecutive_return_inverse,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique_consecutive = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_consecutive_return_inverse_true,\n    if_false=_consecutive_return_inverse_false,\n    module_name=__name__,\n    func_name='unique_consecutive')\nunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__\n\n\ndef tensordot(a, b, dims=2):\n    r\"\"\"Returns a contraction of a and b over multiple dimensions.\n\n    :attr:`tensordot` implements a generalized matrix product.\n\n    Args:\n      a (Tensor): Left tensor to contract\n      b (Tensor): Right tensor to contract\n      dims (int or tuple of two lists of integers): number of dimensions to\n         contract or explicit lists of dimensions for :attr:`a` and\n         :attr:`b` respectively\n\n    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and\n    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,\n    respectively, :func:`~torch.tensordot` computes\n\n    .. math::\n        r_{i_0,...,i_{m-d}, i_d,...,i_n}\n          = \\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}.\n\n    When called with :attr:`dims` of the list form, the given dimensions will be contracted\n    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. 
The sizes\n    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted\n    dimensions.\n\n    Examples::\n\n        >>> a = torch.arange(60.).reshape(3, 4, 5)\n        >>> b = torch.arange(24.).reshape(4, 3, 2)\n        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))\n        tensor([[4400., 4730.],\n                [4532., 4874.],\n                [4664., 5018.],\n                [4796., 5162.],\n                [4928., 5306.]])\n\n        >>> a = torch.randn(3, 4, 5, device='cuda')\n        >>> b = torch.randn(4, 5, 6, device='cuda')\n        >>> c = torch.tensordot(a, b, dims=2).cpu()\n        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],\n                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],\n                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):\n            return handle_torch_function(tensordot, (a, b), a, b, dims=dims)\n    if isinstance(dims, (list, tuple)) or \\\n       (isinstance(dims, torch.Tensor) and dims.numel() > 1):\n        dims_a, dims_b = dims\n    else:\n        if isinstance(dims, torch.Tensor):\n            dims = dims.item()\n        if dims < 0:\n            raise RuntimeError(\"tensordot expects dims >= 0, but got dims={}\".format(dims))\n        dims_a = list(range(-dims, 0))\n        dims_b = list(range(dims))\n    return _VF.tensordot(a, b, dims_a, dims_b)\n\ndef cartesian_prod(*tensors):\n    \"\"\"Do cartesian product of the given sequence of tensors. 
The behavior is similar to\n    python's `itertools.product`.\n\n    Arguments:\n        *tensors: any number of 1 dimensional tensors.\n\n    Returns:\n        Tensor: A tensor equivalent to converting all the input tensors into lists,\n            do `itertools.product` on these lists, and finally convert the resulting list\n            into tensor.\n\n    Example::\n\n        >>> a = [1, 2, 3]\n        >>> b = [4, 5]\n        >>> list(itertools.product(a, b))\n        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]\n        >>> tensor_a = torch.tensor(a)\n        >>> tensor_b = torch.tensor(b)\n        >>> torch.cartesian_prod(tensor_a, tensor_b)\n        tensor([[1, 4],\n                [1, 5],\n                [2, 4],\n                [2, 5],\n                [3, 4],\n                [3, 5]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(cartesian_prod, tensors, *tensors)\n    return _VF.cartesian_prod(tensors)\n\ndef block_diag(*tensors):\n    \"\"\"Create a block diagonal matrix from provided tensors.\n\n    Arguments:\n        *tensors: One or more tensors with 0, 1, or 2 dimensions.\n\n    Returns:\n        Tensor: A 2 dimensional tensor with all the input tensors arranged in\n            order such that their upper left and lower right corners are\n            diagonally adjacent. 
All other elements are set to 0.\n\n    Example::\n\n        >>> import torch\n        >>> A = torch.tensor([[0, 1], [1, 0]])\n        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])\n        >>> C = torch.tensor(7)\n        >>> D = torch.tensor([1, 2, 3])\n        >>> E = torch.tensor([[4], [5], [6]])\n        >>> torch.block_diag(A, B, C, D, E)\n        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],\n                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])\n    \"\"\"\n    if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n        return handle_torch_function(block_diag, tensors, *tensors)\n    return torch._C._VariableFunctions.block_diag(tensors)\n\n\ndef cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):\n    # type: (Tensor, Tensor, float, str) -> (Tensor)\n    r\"\"\"Computes batched the p-norm distance between each pair of the two collections of row vectors.\n\n    Args:\n        x1 (Tensor): input tensor of shape :math:`B \\times P \\times M`.\n        x2 (Tensor): input tensor of shape :math:`B \\times R \\times M`.\n        p: p value for the p-norm distance to calculate between each vector pair\n            :math:`\\in [0, \\infty]`.\n        compute_mode:\n            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate\n            euclidean distance (p = 2) if P > 25 or R > 25\n            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate\n            euclidean distance (p = 2)\n            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate\n            euclidean 
distance (p = 2)\n            Default: use_mm_for_euclid_dist_if_necessary.\n\n    If x1 has shape :math:`B \\times P \\times M` and x2 has shape :math:`B \\times R \\times M` then the\n    output will have shape :math:`B \\times P \\times R`.\n\n    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`\n    if :math:`p \\in (0, \\infty)`. When :math:`p = 0` it is equivalent to\n    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \\infty`, the closest\n    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.\n\n    Example:\n\n        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])\n        >>> a\n        tensor([[ 0.9041,  0.0196],\n                [-0.3108, -2.4423],\n                [-0.4821,  1.0590]])\n        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])\n        >>> b\n        tensor([[-2.1763, -0.4713],\n                [-0.6986,  1.3702]])\n        >>> torch.cdist(a, b, p=2)\n        tensor([[3.1193, 2.0959],\n                [2.7138, 3.8322],\n                [2.2830, 0.3791]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):\n            return handle_torch_function(\n                cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)\n    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':\n        return _VF.cdist(x1, x2, p, None)\n    elif compute_mode == 'use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 1)\n    elif compute_mode == 'donot_use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 2)\n    else:\n        raise ValueError(\"{} is not a valid value for compute_mode\".format(compute_mode))\n\n# TODO: type dim as BroadcastingList when https://github.com/pytorch/pytorch/issues/33782 is fixed\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, 
dtype=None):  # noqa: 749\n    # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\n@overload  # noqa: 749\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n    pass\n\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    r\"\"\"Returns the matrix norm or vector norm of a given tensor.\n\n    Args:\n        input (Tensor): the input tensor\n        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``\n            The following norms can be calculated:\n\n            =====  ============================  ==========================\n            ord    matrix norm                   vector norm\n            =====  ============================  ==========================\n            None   Frobenius norm                2-norm\n            'fro'  Frobenius norm                --\n            'nuc'  nuclear norm                  --\n            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)\n            =====  ============================  ==========================\n\n        dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,\n            vector norm will be calculated, if it is 2-tuple of ints, matrix norm\n            will be calculated. 
If the value is None, matrix norm will be calculated\n            when the input tensor only has two dimensions, vector norm will be\n            calculated when the input tensor only has one dimension. If the input\n            tensor has more than two dimensions, the vector norm will be applied to\n            last dimension.\n        keepdim (bool, optional): whether the output tensors have :attr:`dim`\n            retained or not. Ignored if :attr:`dim` = ``None`` and\n            :attr:`out` = ``None``. Default: ``False``\n        out (Tensor, optional): the output tensor. Ignored if\n            :attr:`dim` = ``None`` and :attr:`out` = ``None``.\n        dtype (:class:`torch.dtype`, optional): the desired data type of\n            returned tensor. If specified, the input tensor is casted to\n            :attr:'dtype' while performing the operation. Default: None.\n\n\n    Example::\n\n        >>> import torch\n        >>> a = torch.arange(9, dtype= torch.float) - 4\n        >>> b = a.reshape((3, 3))\n        >>> torch.norm(a)\n        tensor(7.7460)\n        >>> torch.norm(b)\n        tensor(7.7460)\n        >>> torch.norm(a, float('inf'))\n        tensor(4.)\n        >>> torch.norm(b, float('inf'))\n        tensor(4.)\n        >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)\n        >>> torch.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000])\n        >>> torch.norm(c, dim=1)\n        tensor([3.7417, 4.2426])\n        >>> torch.norm(c, p=1, dim=1)\n        tensor([6., 6.])\n        >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)\n        >>> torch.norm(d, dim=(1,2))\n        tensor([ 3.7417, 11.2250])\n        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])\n        (tensor(3.7417), tensor(11.2250))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                norm, (input,), input, p=p, dim=dim, 
keepdim=keepdim, out=out, dtype=dtype)\n\n    ndim = input.dim()\n\n\n    # catch default case\n    if dim is None and out is None and dtype is None and p is not None:\n        if isinstance(p, str):\n            if p == \"fro\":\n                return _VF.frobenius_norm(input)\n        if not isinstance(p, str):\n            return _VF.norm(input, p)\n\n    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed\n    # remove the overloads where dim is an int and replace with BraodcastingList1\n    # and remove next four lines, replace _dim with dim\n    if dim is not None:\n        if isinstance(dim, int):\n            _dim = [dim]\n        else:\n            _dim = dim\n    else:\n        _dim = None\n\n    if isinstance(p, str):\n        if p == \"fro\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in frobenius norm\")\n\n            if _dim is None:\n                _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            if out is None:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)\n            else:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)\n        elif p == \"nuc\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in nuclear norm\")\n            if _dim is None:\n                if out is None:\n                    return _VF.nuclear_norm(input, keepdim=keepdim)\n                else:\n                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)\n            else:\n                if out is None:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)\n                else:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)\n        raise RuntimeError(\"only valid string values are 'fro' and 'nuc', found {}\".format(p))\n    else:\n        if _dim is None:\n         
   _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n\n        if out is None:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim)\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)\n        else:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)\n\ndef chain_matmul(*matrices):\n    r\"\"\"Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed\n    using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms\n    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`\n    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.\n    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.\n\n\n    Args:\n        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.\n\n\n    Returns:\n        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \\times p_{i + 1}`, then the product\n        would be of dimensions :math:`p_{1} \\times p_{N + 1}`.\n\n    Example::\n\n        >>> a = torch.randn(3, 4)\n        >>> b = torch.randn(4, 5)\n        >>> c = torch.randn(5, 6)\n        >>> d = torch.randn(6, 7)\n        >>> torch.chain_matmul(a, b, c, d)\n        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],\n                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],\n                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])\n\n    .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):\n            return handle_torch_function(chain_matmul, matrices, *matrices)\n    return _VF.chain_matmul(matrices)\n\n\ndef _lu_impl(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Computes the LU factorization of a matrix or batches of matrices\n    :attr:`A`. Returns a tuple containing the LU factorization and\n    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to\n    ``True``.\n\n    .. note::\n        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,\n        then the returned pivots is a tensor filled with zeros of the appropriate size.\n\n    .. note::\n        LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting\n        to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is\n        available for CUDA.\n\n    .. note::\n        This function does not check if the factorization was successful or not if\n        :attr:`get_infos` is ``True`` since the status of the factorization is present in the\n        third element of the return tuple.\n\n    .. note::\n        In the case of batches of square matrices with size less or\n        equal to 32 on a CUDA device, the LU factorization is repeated\n        for singular matrices due to the bug in the MAGMA library (see\n        magma issue 13).\n\n    .. note::\n       ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.\n\n    Arguments:\n        A (Tensor): the tensor to factor of size :math:`(*, m, n)`\n        pivot (bool, optional): controls whether pivoting is done. 
Default: ``True``\n        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.\n                                    Default: ``False``\n        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,\n                               then the elements in the tuple are Tensor, IntTensor,\n                               and IntTensor. If :attr:`get_infos` is ``False``, then the\n                               elements in the tuple are Tensor, IntTensor. Default: ``None``\n\n    Returns:\n        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing\n\n            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`\n\n            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`\n\n            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of\n              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or\n              each minibatch has succeeded or failed\n\n    Example::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = torch.lu(A)\n        >>> A_LU\n        tensor([[[ 1.3506,  2.5558, -0.0816],\n                 [ 0.1684,  1.1551,  0.1940],\n                 [ 0.1193,  0.6189, -0.5497]],\n\n                [[ 0.4526,  1.2526, -0.3285],\n                 [-0.7988,  0.7175, -0.9701],\n                 [ 0.2634, -0.9255, -0.3459]]])\n        >>> pivots\n        tensor([[ 3,  3,  3],\n                [ 3,  3,  3]], dtype=torch.int32)\n        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)\n        >>> if info.nonzero().size(0) == 0:\n        ...   
print('LU factorization succeeded for all samples!')\n        LU factorization succeeded for all samples!\n    \"\"\"\n    # If get_infos is True, then we don't need to check for errors and vice versa\n    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))\n\ndef _check_list_size(out_len, get_infos, out):\n    # type: (int, bool, List[Tensor]) -> None\n    get_infos_int = 1 if get_infos else 0\n    if out_len - get_infos_int != 2:\n        raise TypeError(\"expected tuple of {} elements but got {}\"\n                        .format(2 + int(get_infos), len(out_len)))\n    if not isinstance(out, (tuple, list)):\n        raise TypeError(\"argument 'out' must be tuple of Tensors, not {}\"\n                        .format(type(out).__name__))\n\ndef _lu_with_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        _check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result  # A_LU, pivots, infos\n\ndef _lu_no_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]\n    # need to check for torch_function here so that we exit if\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        
_check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result[0], result[1]  # A_LU, pivots\n\n# The return type of lu depends on `get_infos`, so in order to resolve the output type\n# of lu in TorchScript we need to statically know the value of `get_infos`\nlu = boolean_dispatch(\n    arg_name='get_infos',\n    arg_index=2,\n    default=False,\n    if_true=_lu_with_infos,\n    if_false=_lu_no_infos,\n    module_name=__name__,\n    func_name='lu')\nlu.__doc__ = _lu_impl.__doc__\n\ndef align_tensors(*tensors):\n    raise RuntimeError('`align_tensors` not yet implemented.')\n"
  },
  {
    "path": "patches/pytorch/1.7.0/functional.diff",
    "content": "4a5,7\n> import librosa  # STFT patch for aarch64\n> import numpy as np\n> \n515,516c518,528\n<     return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n<                     normalized, onesided, return_complex)\n---\n>         \n>     # STFT patch for aarch64\n>     # https://stackoverflow.com/a/66872148\n>     librosa_stft = librosa.stft(input.cpu().detach().numpy().reshape(-1), n_fft, hop_length, win_length, window=\"hann\", center=center, pad_mode=pad_mode)\n>     librosa_stft = np.array([[a.real, a.imag] for a in librosa_stft])\n>     librosa_stft = np.transpose(librosa_stft, axes=[0, 2, 1])\n>     librosa_stft = np.expand_dims(librosa_stft, 0)\n>     librosa_stft = torch.from_numpy(librosa_stft)\n>     return librosa_stft\n>     #return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n>     #                normalized, onesided, return_complex)\n"
  },
  {
    "path": "patches/pytorch/1.7.0/functional.original.py",
    "content": "from typing import (\n    Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING\n)\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.types import _size\nfrom ._lowrank import svd_lowrank, pca_lowrank\nfrom .overrides import has_torch_function, handle_torch_function\nfrom ._jit_internal import boolean_dispatch, List\nfrom ._jit_internal import _overload as overload\n\nTensor = torch.Tensor\nfrom torch import _VF\n\n__all__ = [\n    'atleast_1d',\n    'atleast_2d',\n    'atleast_3d',\n    'align_tensors',\n    'broadcast_tensors',\n    'cartesian_prod',\n    'block_diag',\n    'cdist',\n    'chain_matmul',\n    'einsum',\n    'istft',\n    'lu',\n    'lu_unpack',\n    'norm',\n    'meshgrid',\n    'pca_lowrank',\n    'split',\n    'stft',\n    'svd_lowrank',\n    'tensordot',\n    'unique',\n    'unique_consecutive',\n]\n\n\ndef broadcast_tensors(*tensors):\n    r\"\"\"broadcast_tensors(*tensors) -> List of Tensors\n\n    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.\n\n    Args:\n        *tensors: any number of tensors of the same type\n\n    .. warning::\n\n        More than one element of a broadcasted tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. 
If you need to write\n        to the tensors, please clone them first.\n\n    Example::\n\n        >>> x = torch.arange(3).view(1, 3)\n        >>> y = torch.arange(2).view(2, 1)\n        >>> a, b = torch.broadcast_tensors(x, y)\n        >>> a.size()\n        torch.Size([2, 3])\n        >>> a\n        tensor([[0, 1, 2],\n                [0, 1, 2]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(broadcast_tensors, tensors, *tensors)\n    return _VF.broadcast_tensors(tensors)  # type: ignore\n\n\ndef split(tensor, split_size_or_sections, dim=0):\n    r\"\"\"Splits the tensor into chunks. Each chunk is a view of the original tensor.\n\n    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will\n    be split into equally sized chunks (if possible). Last chunk will be smaller if\n    the tensor size along the given dimension :attr:`dim` is not divisible by\n    :attr:`split_size`.\n\n    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split\n    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according\n    to :attr:`split_size_or_sections`.\n\n    Arguments:\n        tensor (Tensor): tensor to split.\n        split_size_or_sections (int) or (list(int)): size of a single chunk or\n            list of sizes for each chunk\n        dim (int): dimension along which to split the tensor.\n\n    Example::\n        >>> a = torch.arange(10).reshape(5,2)\n        >>> a\n        tensor([[0, 1],\n                [2, 3],\n                [4, 5],\n                [6, 7],\n                [8, 9]])\n        >>> torch.split(a, 2)\n        (tensor([[0, 1],\n                 [2, 3]]),\n         tensor([[4, 5],\n                 [6, 7]]),\n         tensor([[8, 9]]))\n        >>> torch.split(a, [1,4])\n        (tensor([[0, 1]]),\n         tensor([[2, 3],\n                 [4, 5],\n         
        [6, 7],\n                 [8, 9]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(tensor) is not Tensor and has_torch_function((tensor,)):\n            return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,\n                                         dim=dim)\n    # Overwriting reason:\n    # This dispatches to two ATen functions depending on the type of\n    # split_size_or_sections. The branching code is in tensor.py, which we\n    # call here.\n    return tensor.split(split_size_or_sections, dim)\n\n\nif TYPE_CHECKING:\n    _Indices = _size\nelse:\n    _Indices = List[int]\n\n\n# equivalent to itertools.product(indices)\ndef _indices_product(indices: _Indices) -> List[List[int]]:\n    empty_list = torch.jit.annotate(List[int], [])\n    result = [empty_list]\n    for idx in indices:\n        result_temp = torch.jit.annotate(List[List[int]], [])\n        for res in result:\n            for i in range(idx):\n                result_temp.append(res + [i])\n        result = result_temp\n    return result\n\n\ndef _index_tensor_with_indices_list(tensor, indices):\n    # type: (Tensor, List[int]) -> Tensor\n    out = tensor\n    for index in indices:\n        out = out[index]\n    return out\n\n\ndef lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):\n    # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])\n    r\"\"\"Unpacks the data and pivots from a LU factorization of a tensor.\n\n    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.\n\n    Arguments:\n        LU_data (Tensor): the packed LU factorization data\n        LU_pivots (Tensor): the packed LU factorization pivots\n        unpack_data (bool): flag indicating if the data should be unpacked\n        unpack_pivots (bool): flag indicating if the pivots should be unpacked\n\n    Examples::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = A.lu()\n     
   >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>>\n        >>> # can recover A from factorization\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n\n        >>> # LU factorization of a rectangular matrix:\n        >>> A = torch.randn(2, 3, 2)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>> P\n        tensor([[[1., 0., 0.],\n                 [0., 1., 0.],\n                 [0., 0., 1.]],\n\n                [[0., 0., 1.],\n                 [0., 1., 0.],\n                 [1., 0., 0.]]])\n        >>> A_L\n        tensor([[[ 1.0000,  0.0000],\n                 [ 0.4763,  1.0000],\n                 [ 0.3683,  0.1135]],\n\n                [[ 1.0000,  0.0000],\n                 [ 0.2957,  1.0000],\n                 [-0.9668, -0.3335]]])\n        >>> A_U\n        tensor([[[ 2.1962,  1.0881],\n                 [ 0.0000, -0.8681]],\n\n                [[-1.0947,  0.3736],\n                 [ 0.0000,  0.5718]]])\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n        >>> torch.norm(A_ - A)\n        tensor(2.9802e-08)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        tens_ops = (LU_data, LU_pivots)\n        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):\n            return handle_torch_function(\n                lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,\n                unpack_pivots=unpack_pivots)\n    shape = LU_data.shape\n    # In generalized LU factorization, the following shape relations hold:\n    #   A.shape[-2:] == (m, n)\n    #   P.shape[-2:] == (m, m)\n    #   L.shape[-2:] == (m, k)\n    #   U.shape[-2:] == (k, n)\n    # where k = min(m, n)\n    m, n = shape[-2:]\n    k = min(m, n)\n    if unpack_data:\n        U: Optional[Tensor] = LU_data.triu()\n        assert U is not None\n        if m != k:\n            U = U.narrow(-2, 0, k)\n        L: Optional[Tensor] = LU_data.tril()\n        assert L is not None\n        if k 
!= n:\n            L = L.narrow(-1, 0, k)\n        L.diagonal(dim1=-2, dim2=-1).fill_(1)\n    else:\n        L = U = None\n\n    if unpack_pivots:\n        LU_pivots_zero_idx = LU_pivots - 1\n        if LU_data.dim() > 2:\n            P: Optional[Tensor] = torch.eye(m, device=LU_data.device,\n                                            dtype=LU_data.dtype) \\\n                .expand(shape[:-1] + (m,)) \\\n                .clone(memory_format=torch.contiguous_format)\n            assert P is not None\n\n            # TODO: rewrite when TorchScript supports product and map as\n            # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed\n            indices = _indices_product(shape[:-2])\n            for idx in indices:\n                final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n                for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):\n                    final_order[k], final_order[j] = final_order[j], final_order[k]\n                # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list\n                p_idx = _index_tensor_with_indices_list(P, idx)\n                p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))\n        else:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)\n            final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n            for k, j, in enumerate(LU_pivots_zero_idx):\n                final_order[k], final_order[j] = final_order[j], final_order[k]\n            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))\n    else:\n        P = None\n\n    return P, L, U\n\n\ndef einsum(equation, *operands):\n    r\"\"\"einsum(equation, *operands) -> Tensor\n\nThis function provides a way of computing multilinear expressions (i.e. 
sums of products) using the\nEinstein summation convention.\n\nArgs:\n    equation (string): The equation is given in terms of lower case letters (indices) to be associated\n           with each dimension of the operands and result. The left hand side lists the operands\n           dimensions, separated by commas. There should be one index letter per tensor dimension.\n           The right hand side follows after `->` and gives the indices for the output.\n           If the `->` and right hand side are omitted, it implicitly defined as the alphabetically\n           sorted list of all indices appearing exactly once in the left hand side.\n           The indices not apprearing in the output are summed over after multiplying the operands\n           entries.\n           If an index appears several times for the same operand, a diagonal is taken.\n           Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,\n           the ellipsis dimensions are at the beginning of the output.\n    operands (Tensor): The operands to compute the Einstein sum of.\n\n.. note::\n\n    This function does not optimize the given expression, so a different formula for the same computation may\n    run faster or consume less memory. 
Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)\n    can optimize the formula for you.\n\nExamples::\n\n    >>> x = torch.randn(5)\n    >>> y = torch.randn(4)\n    >>> torch.einsum('i,j->ij', x, y)  # outer product\n    tensor([[-0.0570, -0.0286, -0.0231,  0.0197],\n            [ 1.2616,  0.6335,  0.5113, -0.4351],\n            [ 1.4452,  0.7257,  0.5857, -0.4984],\n            [-0.4647, -0.2333, -0.1883,  0.1603],\n            [-1.1130, -0.5588, -0.4510,  0.3838]])\n\n\n    >>> A = torch.randn(3,5,4)\n    >>> l = torch.randn(2,5)\n    >>> r = torch.randn(2,4)\n    >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear\n    tensor([[-0.3430, -5.2405,  0.4494],\n            [ 0.3311,  5.5201, -3.0356]])\n\n\n    >>> As = torch.randn(3,2,5)\n    >>> Bs = torch.randn(3,5,4)\n    >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication\n    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],\n             [-1.6706, -0.8097, -0.8025, -2.1183]],\n\n            [[ 4.2239,  0.3107, -0.5756, -0.2354],\n             [-1.4558, -0.3460,  1.5087, -0.8530]],\n\n            [[ 2.8153,  1.8787, -4.3839, -1.2112],\n             [ 0.3728, -2.1131,  0.0921,  0.8305]]])\n\n    >>> A = torch.randn(3, 3)\n    >>> torch.einsum('ii->i', A) # diagonal\n    tensor([-0.7825,  0.8291, -0.1936])\n\n    >>> A = torch.randn(4, 3, 3)\n    >>> torch.einsum('...ii->...i', A) # batch diagonal\n    tensor([[-1.0864,  0.7292,  0.0569],\n            [-0.9725, -1.0270,  0.6493],\n            [ 0.5832, -1.1716, -1.5084],\n            [ 0.4041, -1.1690,  0.8570]])\n\n    >>> A = torch.randn(2, 3, 4, 5)\n    >>> torch.einsum('...ij->...ji', A).shape # batch permute\n    torch.Size([2, 3, 5, 4])\n\"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):\n            return handle_torch_function(einsum, operands, equation, *operands)\n    if len(operands) == 1 and 
isinstance(operands[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        _operands = operands[0]\n        # recurse incase operands contains value that has torch function\n        # in the original implementation this line is omitted\n        return einsum(equation, *_operands)\n\n    return _VF.einsum(equation, operands)  # type: ignore\n\n\nif TYPE_CHECKING:\n    # The JIT doesn't understand Union, so only add type annotation for mypy\n    def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:\n        return _meshgrid(*tensors)\nelse:\n    def meshgrid(*tensors):\n        return _meshgrid(*tensors)\n\n\ndef _meshgrid(*tensors):\n    r\"\"\"Take :math:`N` tensors, each of which can be either scalar or 1-dimensional\nvector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by\nexpanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.\n\n\n    Args:\n        tensors (list of Tensor): list of scalars or 1 dimensional tensors. 
Scalars will be\n        treated as tensors of size :math:`(1,)` automatically\n\n    Returns:\n        seq (sequence of Tensors): If the input has :math:`k` tensors of size\n        :math:`(N_1,), (N_2,), \\ldots , (N_k,)`, then the output would also have :math:`k` tensors,\n        where all tensors are of size :math:`(N_1, N_2, \\ldots , N_k)`.\n\n    Example::\n\n        >>> x = torch.tensor([1, 2, 3])\n        >>> y = torch.tensor([4, 5, 6])\n        >>> grid_x, grid_y = torch.meshgrid(x, y)\n        >>> grid_x\n        tensor([[1, 1, 1],\n                [2, 2, 2],\n                [3, 3, 3]])\n        >>> grid_y\n        tensor([[4, 5, 6],\n                [4, 5, 6],\n                [4, 5, 6]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(meshgrid, tensors, *tensors)\n    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        tensors = tensors[0]  # type: ignore\n    return _VF.meshgrid(tensors)  # type: ignore\n\n\ndef stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,\n         win_length: Optional[int] = None, window: Optional[Tensor] = None,\n         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,\n         onesided: Optional[bool] = None,\n         return_complex: Optional[bool] = None) -> Tensor:\n    r\"\"\"Short-time Fourier transform (STFT).\n\n    .. warning::\n        Setting :attr:`return_complex` explicitly will be required in a future\n        PyTorch release. Set it to False to preserve the current behavior or\n        True to return a complex output.\n\n    The STFT computes the Fourier transform of short overlapping windows of the\n    input. This giving frequency components of the signal as they change over\n    time. 
The interface of this function is modeled after the librosa_ stft function.\n\n    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html\n\n    Ignoring the optional batch dimension, this method computes the following\n    expression:\n\n    .. math::\n        X[m, \\omega] = \\sum_{k = 0}^{\\text{win\\_length-1}}%\n                            \\text{window}[k]\\ \\text{input}[m \\times \\text{hop\\_length} + k]\\ %\n                            \\exp\\left(- j \\frac{2 \\pi \\cdot \\omega k}{\\text{win\\_length}}\\right),\n\n    where :math:`m` is the index of the sliding window, and :math:`\\omega` is\n    the frequency that :math:`0 \\leq \\omega < \\text{n\\_fft}`. When\n    :attr:`onesided` is the default value ``True``,\n\n    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time\n      sequences.\n\n    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to\n      ``floor(n_fft / 4)``.\n\n    * If :attr:`win_length` is ``None`` (default), it is treated as equal to\n      :attr:`n_fft`.\n\n    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from\n      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is\n      treated as if having :math:`1` everywhere in the window. If\n      :math:`\\text{win\\_length} < \\text{n\\_fft}`, :attr:`window` will be padded on\n      both sides to length :attr:`n_fft` before being applied.\n\n    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on\n      both sides so that the :math:`t`-th frame is centered at time\n      :math:`t \\times \\text{hop\\_length}`. Otherwise, the :math:`t`-th frame\n      begins at time  :math:`t \\times \\text{hop\\_length}`.\n\n    * :attr:`pad_mode` determines the padding method used on :attr:`input` when\n      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for\n      all available options. 
Default is ``\"reflect\"``.\n\n    * If :attr:`onesided` is ``True`` (default for real input), only values for\n      :math:`\\omega` in :math:`\\left[0, 1, 2, \\dots, \\left\\lfloor\n      \\frac{\\text{n\\_fft}}{2} \\right\\rfloor + 1\\right]` are returned because\n      the real-to-complex Fourier transform satisfies the conjugate symmetry,\n      i.e., :math:`X[m, \\omega] = X[m, \\text{n\\_fft} - \\omega]^*`.\n      Note if the input or window tensors are complex, then :attr:`onesided`\n      output is not possible.\n\n    * If :attr:`normalized` is ``True`` (default is ``False``), the function\n      returns the normalized STFT results, i.e., multiplied by :math:`(\\text{frame\\_length})^{-0.5}`.\n\n    * If :attr:`return_complex` is ``True`` (default if input is complex), the\n      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,\n      the output is a ``input.dim() + 2`` dimensional real tensor where the last\n      dimension represents the real and imaginary components.\n\n    Returns either a complex tensor of size :math:`(* \\times N \\times T)` if\n    :attr:`return_complex` is true, or a real tensor of size :math:`(* \\times N\n    \\times T \\times 2)`. Where :math:`*` is the optional batch size of\n    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied\n    and :math:`T` is the total number of frames used.\n\n    .. warning::\n      This function changed signature at version 0.4.1. Calling with the\n      previous signature may cause error or return incorrect result.\n\n    Arguments:\n        input (Tensor): the input tensor\n        n_fft (int): size of Fourier transform\n        hop_length (int, optional): the distance between neighboring sliding window\n            frames. 
Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)\n        win_length (int, optional): the size of window frame and STFT filter.\n            Default: ``None``  (treated as equal to :attr:`n_fft`)\n        window (Tensor, optional): the optional window function.\n            Default: ``None`` (treated as window of all :math:`1` s)\n        center (bool, optional): whether to pad :attr:`input` on both sides so\n            that the :math:`t`-th frame is centered at time :math:`t \\times \\text{hop\\_length}`.\n            Default: ``True``\n        pad_mode (string, optional): controls the padding method used when\n            :attr:`center` is ``True``. Default: ``\"reflect\"``\n        normalized (bool, optional): controls whether to return the normalized STFT results\n             Default: ``False``\n        onesided (bool, optional): controls whether to return half of results to\n            avoid redundancy for real inputs.\n            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.\n        return_complex (bool, optional): whether to return a complex tensor, or\n            a real tensor with an extra last dimension for the real and\n            imaginary components.\n\n    Returns:\n        Tensor: A tensor containing the STFT result with shape described above\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, pad_mode=pad_mode, normalized=normalized,\n                onesided=onesided, return_complex=return_complex)\n    # TODO: after having proper ways to map Python strings to ATen Enum, move\n    #       this and F.pad to ATen.\n    if center:\n        signal_dim = input.dim()\n        extended_shape = [1] * (3 - signal_dim) + list(input.size())\n        pad = 
int(n_fft // 2)\n        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)\n        input = input.view(input.shape[-signal_dim:])\n    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n                    normalized, onesided, return_complex)\n\ndef istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,\n          win_length: Optional[int] = None, window: Optional[Tensor] = None,\n          center: bool = True, normalized: bool = False,\n          onesided: Optional[bool] = None, length: Optional[int] = None,\n          return_complex: bool = False) -> Tensor:\n    r\"\"\"Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.\n    It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the\n    least squares estimation of the original signal. The algorithm will check using the NOLA condition (\n    nonzero overlap).\n\n    Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop\n    created by the summation of all the windows is never zero at certain point in time. Specifically,\n    :math:`\\sum_{t=-\\infty}^{\\infty} |w|^2[n-t\\times hop\\_length] \\cancel{=} 0`.\n\n    Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,\n    ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False\n    since the signal isn't padded).\n\n    If :attr:`center` is ``True``, then there will be padding e.g. 
``'constant'``, ``'reflect'``, etc.\n    Left padding can be trimmed off exactly because they can be calculated but right padding cannot be\n    calculated without additional information.\n\n    Example: Suppose the last window is:\n    ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``\n\n    The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation\n    of right padding. These additional values could be zeros or a reflection of the signal so providing\n    :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed\n    (some loss of signal).\n\n    [1] D. W. Griffin and J. S. Lim, \"Signal estimation from modified short-time Fourier transform,\"\n    IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.\n\n    Arguments:\n        input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`,\n            can either be complex (``channel``, ``fft_size``, ``n_frame``), or real\n            (``channel``, ``fft_size``, ``n_frame``, 2) where the ``channel``\n            dimension is optional.\n        n_fft (int): Size of Fourier transform\n        hop_length (Optional[int]): The distance between neighboring sliding window frames.\n            (Default: ``n_fft // 4``)\n        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)\n        window (Optional[torch.Tensor]): The optional window function.\n            (Default: ``torch.ones(win_length)``)\n        center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is\n            centered at time :math:`t \\times \\text{hop\\_length}`.\n            (Default: ``True``)\n        normalized (bool): Whether the STFT was normalized. 
(Default: ``False``)\n        onesided (Optional[bool]): Whether the STFT was onesided.\n            (Default: ``True`` if ``n_fft != fft_size`` in the input size)\n        length (Optional[int]): The amount to trim the signal by (i.e. the\n            original signal length). (Default: whole signal)\n        return_complex (Optional[bool]):\n            Whether the output should be complex, or if the input should be\n            assumed to derive from a real signal and window.\n            Note that this is incompatible with ``onesided=True``.\n            (Default: ``False``)\n\n    Returns:\n        Tensor: Least squares estimation of the original signal of size (..., signal_length)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, normalized=normalized, onesided=onesided,\n                length=length, return_complex=return_complex)\n\n    return _VF.istft(input, n_fft, hop_length, win_length, window, center,  # type: ignore\n                     normalized, onesided, length, return_complex)\n\n\ndel torch.unique_dim\n\n\nif TYPE_CHECKING:\n    # These _impl functions return a variable number of tensors as output with\n    # __torch_function__; tuple unpacking is done already rather than being\n    # done by the caller of the _impl function\n    _unique_impl_out = Any\nelse:\n    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]\n\n\ndef _unique_impl(input: Tensor, sorted: bool = True,\n                 return_inverse: bool = False, return_counts: bool = False,\n                 dim: Optional[int] = None) -> _unique_impl_out:\n    r\"\"\"Returns the unique elements of the input tensor.\n\n    .. 
note:: This function is different from :func:`torch.unique_consecutive` in the sense that\n        this function also eliminates non-consecutive duplicate values.\n\n    .. note:: Currently in the CUDA implementation and the CPU implementation when dim is specified,\n        `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument.\n        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use\n        :func:`torch.unique_consecutive` which avoids the sorting.\n\n    Arguments:\n        input (Tensor): the input tensor\n        sorted (bool): Whether to sort the unique elements in ascending order\n            before returning as output.\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))\n        >>> output\n        tensor([ 2,  3,  1])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([ 0,  2,  1,  2])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([[ 0,  2],\n                [ 1,  2]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique, (input,), input, sorted=sorted, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n\n    if dim is not None:\n        output, inverse_indices, counts = 
_VF.unique_dim(  # type: ignore\n            input,\n            dim,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    else:\n        output, inverse_indices, counts = torch._unique2(\n            input,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    return output, inverse_indices, counts\n\n\ndef _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,\n                             return_counts: bool = False,\n                             dim: Optional[int] = None) -> _unique_impl_out:\n    r\"\"\"Eliminates all but the first element from every consecutive group of equivalent elements.\n\n    .. note:: This function is different from :func:`torch.unique` in the sense that this function\n        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`\n        in C++.\n\n    Arguments:\n        input (Tensor): the input tensor\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])\n        >>> output = torch.unique_consecutive(x)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n\n        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> inverse_indices\n        tensor([0, 0, 1, 1, 2, 3, 3, 4])\n\n        >>> output, counts = torch.unique_consecutive(x, return_counts=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> counts\n        tensor([2, 2, 1, 2, 1])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique_consecutive, (input,), input, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore\n        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)\n    return output, inverse_indices, 
counts\n\n\ndef _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, counts\n\n\ndef _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output\n\n\ndef _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n\n_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_return_counts,\n    if_false=_return_output,\n    module_name=__name__,\n    func_name='unique')\n\n_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_unique_impl,\n    if_false=_return_inverse,\n    module_name=__name__,\n    func_name='unique')\n\n# The return type of unique depends on `return_inverse`, and 
`return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_return_inverse_true,\n    if_false=_return_inverse_false,\n    module_name=__name__,\n    func_name='unique')\nunique.__doc__ = _unique_impl.__doc__\n\n\ndef _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, counts\n\n\ndef _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output\n\n\ndef _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n\n_consecutive_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    
if_true=_consecutive_return_counts,\n    if_false=_consecutive_return_output,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n_consecutive_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_unique_consecutive_impl,\n    if_false=_consecutive_return_inverse,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique_consecutive = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_consecutive_return_inverse_true,\n    if_false=_consecutive_return_inverse_false,\n    module_name=__name__,\n    func_name='unique_consecutive')\nunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__\n\n\ndef tensordot(a, b, dims=2):\n    r\"\"\"Returns a contraction of a and b over multiple dimensions.\n\n    :attr:`tensordot` implements a generalized matrix product.\n\n    Args:\n      a (Tensor): Left tensor to contract\n      b (Tensor): Right tensor to contract\n      dims (int or tuple of two lists of integers): number of dimensions to\n         contract or explicit lists of dimensions for :attr:`a` and\n         :attr:`b` respectively\n\n    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and\n    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,\n    respectively, :func:`~torch.tensordot` computes\n\n    .. math::\n        r_{i_0,...,i_{m-d}, i_d,...,i_n}\n          = \\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}.\n\n    When called with :attr:`dims` of the list form, the given dimensions will be contracted\n    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. 
The sizes\n    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted\n    dimensions.\n\n    Examples::\n\n        >>> a = torch.arange(60.).reshape(3, 4, 5)\n        >>> b = torch.arange(24.).reshape(4, 3, 2)\n        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))\n        tensor([[4400., 4730.],\n                [4532., 4874.],\n                [4664., 5018.],\n                [4796., 5162.],\n                [4928., 5306.]])\n\n        >>> a = torch.randn(3, 4, 5, device='cuda')\n        >>> b = torch.randn(4, 5, 6, device='cuda')\n        >>> c = torch.tensordot(a, b, dims=2).cpu()\n        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],\n                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],\n                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):\n            return handle_torch_function(tensordot, (a, b), a, b, dims=dims)\n    if isinstance(dims, (list, tuple)) or \\\n       (isinstance(dims, torch.Tensor) and dims.numel() > 1):\n        dims_a, dims_b = dims\n    else:\n        if isinstance(dims, torch.Tensor):\n            dims = dims.item()\n        if dims < 0:\n            raise RuntimeError(f\"tensordot expects dims >= 0, but got dims={dims}\")\n        dims_a = list(range(-dims, 0))\n        dims_b = list(range(dims))\n    return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore\n\ndef cartesian_prod(*tensors):\n    \"\"\"Do cartesian product of the given sequence of tensors. 
The behavior is similar to\n    python's `itertools.product`.\n\n    Arguments:\n        *tensors: any number of 1 dimensional tensors.\n\n    Returns:\n        Tensor: A tensor equivalent to converting all the input tensors into lists,\n            do `itertools.product` on these lists, and finally convert the resulting list\n            into tensor.\n\n    Example::\n\n        >>> a = [1, 2, 3]\n        >>> b = [4, 5]\n        >>> list(itertools.product(a, b))\n        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]\n        >>> tensor_a = torch.tensor(a)\n        >>> tensor_b = torch.tensor(b)\n        >>> torch.cartesian_prod(tensor_a, tensor_b)\n        tensor([[1, 4],\n                [1, 5],\n                [2, 4],\n                [2, 5],\n                [3, 4],\n                [3, 5]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(cartesian_prod, tensors, *tensors)\n    return _VF.cartesian_prod(tensors)  # type: ignore\n\ndef block_diag(*tensors):\n    \"\"\"Create a block diagonal matrix from provided tensors.\n\n    Arguments:\n        *tensors: One or more tensors with 0, 1, or 2 dimensions.\n\n    Returns:\n        Tensor: A 2 dimensional tensor with all the input tensors arranged in\n            order such that their upper left and lower right corners are\n            diagonally adjacent. 
All other elements are set to 0.\n\n    Example::\n\n        >>> import torch\n        >>> A = torch.tensor([[0, 1], [1, 0]])\n        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])\n        >>> C = torch.tensor(7)\n        >>> D = torch.tensor([1, 2, 3])\n        >>> E = torch.tensor([[4], [5], [6]])\n        >>> torch.block_diag(A, B, C, D, E)\n        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],\n                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])\n    \"\"\"\n    if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n        return handle_torch_function(block_diag, tensors, *tensors)\n    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore\n\n\ndef cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):\n    # type: (Tensor, Tensor, float, str) -> (Tensor)\n    r\"\"\"Computes batched the p-norm distance between each pair of the two collections of row vectors.\n\n    Args:\n        x1 (Tensor): input tensor of shape :math:`B \\times P \\times M`.\n        x2 (Tensor): input tensor of shape :math:`B \\times R \\times M`.\n        p: p value for the p-norm distance to calculate between each vector pair\n            :math:`\\in [0, \\infty]`.\n        compute_mode:\n            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate\n            euclidean distance (p = 2) if P > 25 or R > 25\n            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate\n            euclidean distance (p = 2)\n            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate\n        
    euclidean distance (p = 2)\n            Default: use_mm_for_euclid_dist_if_necessary.\n\n    If x1 has shape :math:`B \\times P \\times M` and x2 has shape :math:`B \\times R \\times M` then the\n    output will have shape :math:`B \\times P \\times R`.\n\n    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`\n    if :math:`p \\in (0, \\infty)`. When :math:`p = 0` it is equivalent to\n    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \\infty`, the closest\n    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.\n\n    Example:\n\n        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])\n        >>> a\n        tensor([[ 0.9041,  0.0196],\n                [-0.3108, -2.4423],\n                [-0.4821,  1.0590]])\n        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])\n        >>> b\n        tensor([[-2.1763, -0.4713],\n                [-0.6986,  1.3702]])\n        >>> torch.cdist(a, b, p=2)\n        tensor([[3.1193, 2.0959],\n                [2.7138, 3.8322],\n                [2.2830, 0.3791]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):\n            return handle_torch_function(\n                cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)\n    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':\n        return _VF.cdist(x1, x2, p, None)  # type: ignore\n    elif compute_mode == 'use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 1)  # type: ignore\n    elif compute_mode == 'donot_use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 2)  # type: ignore\n    else:\n        raise ValueError(f\"{compute_mode} is not a valid value for compute_mode\")\n\ndef atleast_1d(*tensors):\n    r\"\"\"\n    Returns a 1-dimensional view of each input tensor with zero dimensions.\n    Input tensors 
with one or more dimensions are returned as-is.\n\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example::\n        >>> x = torch.randn(2)\n        >>> x\n        tensor([1.4584, 0.7583])\n        >>> torch.atleast_1d(x)\n        tensor([1.4584, 0.7583])\n        >>> x = torch.tensor(1.)\n        >>> x\n        tensor(1.)\n        >>> torch.atleast_1d(x)\n        tensor([1.])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_1d((x,y))\n        (tensor([0.5000]), tensor([1.]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_1d, tensors, *tensors)\n    if len(tensors) == 1:\n        tensors = tensors[0]\n    return _VF.atleast_1d(tensors)  # type: ignore\n\ndef atleast_2d(*tensors):\n    r\"\"\"\n    Returns a 2-dimensional view of each each input tensor with zero dimensions.\n    Input tensors with two or more dimensions are returned as-is.\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example::\n        >>> x = torch.tensor(1.)\n        >>> x\n        tensor(1.)\n        >>> torch.atleast_2d(x)\n        tensor([[1.]])\n        >>> x = torch.randn(2,2)\n        >>> x\n        tensor([[2.2086, 2.5165],\n                [0.1757, 0.5194]])\n        >>> torch.atleast_2d(x)\n        tensor([[2.2086, 2.5165],\n                [0.1757, 0.5194]])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_2d((x,y))\n        (tensor([[0.5000]]), tensor([[1.]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_2d, tensors, *tensors)\n    if len(tensors) == 1:\n        
tensors = tensors[0]\n    return _VF.atleast_2d(tensors)  # type: ignore\n\ndef atleast_3d(*tensors):\n    r\"\"\"\n    Returns a 3-dimensional view of each each input tensor with zero dimensions.\n    Input tensors with three or more dimensions are returned as-is.\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example:\n\n        >>> x = torch.tensor(0.5)\n        >>> x\n        tensor(0.5000)\n        >>> torch.atleast_3d(x)\n        tensor([[[0.5000]]])\n        >>> y = torch.randn(2,2)\n        >>> y\n        tensor([[-0.8079,  0.7460],\n                [-1.1647,  1.4734]])\n        >>> torch.atleast_3d(y)\n        tensor([[[-0.8079],\n                [ 0.7460]],\n                <BLANKLINE>\n                [[-1.1647],\n                [ 1.4734]]])\n        >>> x = torch.randn(1,1,1)\n        >>> x\n        tensor([[[-1.5689]]])\n        >>> torch.atleast_3d(x)\n        tensor([[[-1.5689]]])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_3d((x,y))\n        (tensor([[[0.5000]]]), tensor([[[1.]]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_3d, tensors, *tensors)\n    if len(tensors) == 1:\n        tensors = tensors[0]\n    return _VF.atleast_3d(tensors)  # type: ignore\n\n\nif TYPE_CHECKING:\n    pass\n    # There's no good way to use this type annotation; cannot rename norm() to\n    # _norm_impl() in a way that doesn't break JIT overloads. 
So leave untyped\n    # for mypy for now.\n    #    def norm(input: Tensor,\n    #             p: Optional[Union[str, Number]] = \"fro\",\n    #             dim: Optional[Union[int, List[int]]] = None,\n    #             keepdim: bool = False,\n    #             out: Optional[Tensor] = None,\n    #             dtype: _dtype = None) -> Tensor:\n    #        return _norm_impl(input, p, dim, keepdim, out, dtype)\nelse:\n    # TODO: type dim as BroadcastingList when\n    # https://github.com/pytorch/pytorch/issues/33782 is fixed\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    r\"\"\"Returns the matrix norm or vector norm of a given tensor.\n\n    .. 
warning::\n\n        torch.norm is deprecated and may be removed in a future PyTorch release.\n        Use :func:`torch.linalg.norm` instead, but note that :func:`torch.linalg.norm`\n        has a different signature and slightly different behavior that is\n        more consistent with NumPy's numpy.linalg.norm.\n\n    Args:\n        input (Tensor): the input tensor\n        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``\n            The following norms can be calculated:\n\n            =====  ============================  ==========================\n            ord    matrix norm                   vector norm\n            =====  ============================  ==========================\n            None   Frobenius norm                2-norm\n            'fro'  Frobenius norm                --\n            'nuc'  nuclear norm                  --\n            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)\n            =====  ============================  ==========================\n\n        dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,\n            vector norm will be calculated, if it is 2-tuple of ints, matrix norm\n            will be calculated. If the value is None, matrix norm will be calculated\n            when the input tensor only has two dimensions, vector norm will be\n            calculated when the input tensor only has one dimension. If the input\n            tensor has more than two dimensions, the vector norm will be applied to\n            last dimension.\n        keepdim (bool, optional): whether the output tensors have :attr:`dim`\n            retained or not. Ignored if :attr:`dim` = ``None`` and\n            :attr:`out` = ``None``. Default: ``False``\n        out (Tensor, optional): the output tensor. 
Ignored if\n            :attr:`dim` = ``None`` and :attr:`out` = ``None``.\n        dtype (:class:`torch.dtype`, optional): the desired data type of\n            returned tensor. If specified, the input tensor is casted to\n            :attr:'dtype' while performing the operation. Default: None.\n\n\n    Example::\n\n        >>> import torch\n        >>> a = torch.arange(9, dtype= torch.float) - 4\n        >>> b = a.reshape((3, 3))\n        >>> torch.norm(a)\n        tensor(7.7460)\n        >>> torch.norm(b)\n        tensor(7.7460)\n        >>> torch.norm(a, float('inf'))\n        tensor(4.)\n        >>> torch.norm(b, float('inf'))\n        tensor(4.)\n        >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)\n        >>> torch.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000])\n        >>> torch.norm(c, dim=1)\n        tensor([3.7417, 4.2426])\n        >>> torch.norm(c, p=1, dim=1)\n        tensor([6., 6.])\n        >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)\n        >>> torch.norm(d, dim=(1,2))\n        tensor([ 3.7417, 11.2250])\n        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])\n        (tensor(3.7417), tensor(11.2250))\n    \"\"\"\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)\n\n    ndim = input.dim()\n\n    # catch default case\n    if dim is None and out is None and dtype is None and p is not None:\n        if isinstance(p, str):\n            if p == \"fro\":\n                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)  # type: ignore\n        if not isinstance(p, str):\n            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore\n\n    # TODO: when 
https://github.com/pytorch/pytorch/issues/33782 is fixed\n    # remove the overloads where dim is an int and replace with BraodcastingList1\n    # and remove next four lines, replace _dim with dim\n    if dim is not None:\n        if isinstance(dim, int):\n            _dim = [dim]\n        else:\n            _dim = dim\n    else:\n        _dim = None  # type: ignore\n\n    if isinstance(p, str):\n        if p == \"fro\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in frobenius norm\")\n\n            if _dim is None:\n                _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            if out is None:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore\n            else:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore\n        elif p == \"nuc\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in nuclear norm\")\n            if _dim is None:\n                if out is None:\n                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore\n                else:\n                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore\n            else:\n                if out is None:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore\n                else:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore\n        raise RuntimeError(f\"only valid string values are 'fro' and 'nuc', found {p}\")\n    else:\n        if _dim is None:\n            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n\n        if out is None:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore\n            else:\n                return 
_VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore\n        else:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore\n\ndef chain_matmul(*matrices):\n    r\"\"\"Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed\n    using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms\n    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`\n    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.\n    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.\n\n\n    Args:\n        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.\n\n\n    Returns:\n        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \\times p_{i + 1}`, then the product\n        would be of dimensions :math:`p_{1} \\times p_{N + 1}`.\n\n    Example::\n\n        >>> a = torch.randn(3, 4)\n        >>> b = torch.randn(4, 5)\n        >>> c = torch.randn(5, 6)\n        >>> d = torch.randn(6, 7)\n        >>> torch.chain_matmul(a, b, c, d)\n        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],\n                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],\n                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])\n\n    .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):\n            return handle_torch_function(chain_matmul, matrices, *matrices)\n    return _VF.chain_matmul(matrices)  # type: ignore\n\n\ndef _lu_impl(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Computes the LU factorization of a matrix or batches of matrices\n    :attr:`A`. Returns a tuple containing the LU factorization and\n    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to\n    ``True``.\n\n    .. note::\n        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,\n        then the returned pivots is a tensor filled with zeros of the appropriate size.\n\n    .. note::\n        LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting\n        to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is\n        available for CUDA.\n\n    .. note::\n        This function does not check if the factorization was successful or not if\n        :attr:`get_infos` is ``True`` since the status of the factorization is present in the\n        third element of the return tuple.\n\n    .. note::\n        In the case of batches of square matrices with size less or\n        equal to 32 on a CUDA device, the LU factorization is repeated\n        for singular matrices due to the bug in the MAGMA library (see\n        magma issue 13).\n\n    .. note::\n       ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.\n\n    Arguments:\n        A (Tensor): the tensor to factor of size :math:`(*, m, n)`\n        pivot (bool, optional): controls whether pivoting is done. 
Default: ``True``\n        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.\n                                    Default: ``False``\n        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,\n                               then the elements in the tuple are Tensor, IntTensor,\n                               and IntTensor. If :attr:`get_infos` is ``False``, then the\n                               elements in the tuple are Tensor, IntTensor. Default: ``None``\n\n    Returns:\n        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing\n\n            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`\n\n            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`\n\n            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of\n              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or\n              each minibatch has succeeded or failed\n\n    Example::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = torch.lu(A)\n        >>> A_LU\n        tensor([[[ 1.3506,  2.5558, -0.0816],\n                 [ 0.1684,  1.1551,  0.1940],\n                 [ 0.1193,  0.6189, -0.5497]],\n\n                [[ 0.4526,  1.2526, -0.3285],\n                 [-0.7988,  0.7175, -0.9701],\n                 [ 0.2634, -0.9255, -0.3459]]])\n        >>> pivots\n        tensor([[ 3,  3,  3],\n                [ 3,  3,  3]], dtype=torch.int32)\n        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)\n        >>> if info.nonzero().size(0) == 0:\n        ...   
print('LU factorization succeeded for all samples!')\n        LU factorization succeeded for all samples!\n    \"\"\"\n    # If get_infos is True, then we don't need to check for errors and vice versa\n    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))\n\n\nif TYPE_CHECKING:\n    _ListOrSeq = Sequence[Tensor]\nelse:\n    _ListOrSeq = List[Tensor]\n\ndef _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:\n    get_infos_int = 1 if get_infos else 0\n    if out_len - get_infos_int != 2:\n        raise TypeError(f\"expected tuple of {2 + int(get_infos)} elements but got {out_len}\")\n    if not isinstance(out, (tuple, list)):\n        raise TypeError(f\"argument 'out' must be tuple of Tensors, not {type(out).__name__}\")\n\ndef _lu_with_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        _check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result  # A_LU, pivots, infos\n\ndef _lu_no_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]\n    # need to check for torch_function here so that we exit if\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        
_check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result[0], result[1]  # A_LU, pivots\n\n# The return type of lu depends on `get_infos`, so in order to resolve the output type\n# of lu in TorchScript we need to statically know the value of `get_infos`\nlu = boolean_dispatch(\n    arg_name='get_infos',\n    arg_index=2,\n    default=False,\n    if_true=_lu_with_infos,\n    if_false=_lu_no_infos,\n    module_name=__name__,\n    func_name='lu')\nlu.__doc__ = _lu_impl.__doc__\n\ndef align_tensors(*tensors):\n    raise RuntimeError('`align_tensors` not yet implemented.')\n"
  },
  {
    "path": "patches/pytorch/1.7.0/functional.py",
    "content": "from typing import (\n    Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING\n)\n\nimport librosa  # STFT patch for aarch64\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.types import _size\nfrom ._lowrank import svd_lowrank, pca_lowrank\nfrom .overrides import has_torch_function, handle_torch_function\nfrom ._jit_internal import boolean_dispatch, List\nfrom ._jit_internal import _overload as overload\n\nTensor = torch.Tensor\nfrom torch import _VF\n\n__all__ = [\n    'atleast_1d',\n    'atleast_2d',\n    'atleast_3d',\n    'align_tensors',\n    'broadcast_tensors',\n    'cartesian_prod',\n    'block_diag',\n    'cdist',\n    'chain_matmul',\n    'einsum',\n    'istft',\n    'lu',\n    'lu_unpack',\n    'norm',\n    'meshgrid',\n    'pca_lowrank',\n    'split',\n    'stft',\n    'svd_lowrank',\n    'tensordot',\n    'unique',\n    'unique_consecutive',\n]\n\n\ndef broadcast_tensors(*tensors):\n    r\"\"\"broadcast_tensors(*tensors) -> List of Tensors\n\n    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.\n\n    Args:\n        *tensors: any number of tensors of the same type\n\n    .. warning::\n\n        More than one element of a broadcasted tensor may refer to a single\n        memory location. As a result, in-place operations (especially ones that\n        are vectorized) may result in incorrect behavior. 
If you need to write\n        to the tensors, please clone them first.\n\n    Example::\n\n        >>> x = torch.arange(3).view(1, 3)\n        >>> y = torch.arange(2).view(2, 1)\n        >>> a, b = torch.broadcast_tensors(x, y)\n        >>> a.size()\n        torch.Size([2, 3])\n        >>> a\n        tensor([[0, 1, 2],\n                [0, 1, 2]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(broadcast_tensors, tensors, *tensors)\n    return _VF.broadcast_tensors(tensors)  # type: ignore\n\n\ndef split(tensor, split_size_or_sections, dim=0):\n    r\"\"\"Splits the tensor into chunks. Each chunk is a view of the original tensor.\n\n    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will\n    be split into equally sized chunks (if possible). Last chunk will be smaller if\n    the tensor size along the given dimension :attr:`dim` is not divisible by\n    :attr:`split_size`.\n\n    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split\n    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according\n    to :attr:`split_size_or_sections`.\n\n    Arguments:\n        tensor (Tensor): tensor to split.\n        split_size_or_sections (int) or (list(int)): size of a single chunk or\n            list of sizes for each chunk\n        dim (int): dimension along which to split the tensor.\n\n    Example::\n        >>> a = torch.arange(10).reshape(5,2)\n        >>> a\n        tensor([[0, 1],\n                [2, 3],\n                [4, 5],\n                [6, 7],\n                [8, 9]])\n        >>> torch.split(a, 2)\n        (tensor([[0, 1],\n                 [2, 3]]),\n         tensor([[4, 5],\n                 [6, 7]]),\n         tensor([[8, 9]]))\n        >>> torch.split(a, [1,4])\n        (tensor([[0, 1]]),\n         tensor([[2, 3],\n                 [4, 5],\n         
        [6, 7],\n                 [8, 9]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(tensor) is not Tensor and has_torch_function((tensor,)):\n            return handle_torch_function(split, (tensor,), tensor, split_size_or_sections,\n                                         dim=dim)\n    # Overwriting reason:\n    # This dispatches to two ATen functions depending on the type of\n    # split_size_or_sections. The branching code is in tensor.py, which we\n    # call here.\n    return tensor.split(split_size_or_sections, dim)\n\n\nif TYPE_CHECKING:\n    _Indices = _size\nelse:\n    _Indices = List[int]\n\n\n# equivalent to itertools.product(indices)\ndef _indices_product(indices: _Indices) -> List[List[int]]:\n    empty_list = torch.jit.annotate(List[int], [])\n    result = [empty_list]\n    for idx in indices:\n        result_temp = torch.jit.annotate(List[List[int]], [])\n        for res in result:\n            for i in range(idx):\n                result_temp.append(res + [i])\n        result = result_temp\n    return result\n\n\ndef _index_tensor_with_indices_list(tensor, indices):\n    # type: (Tensor, List[int]) -> Tensor\n    out = tensor\n    for index in indices:\n        out = out[index]\n    return out\n\n\ndef lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):\n    # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])\n    r\"\"\"Unpacks the data and pivots from a LU factorization of a tensor.\n\n    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.\n\n    Arguments:\n        LU_data (Tensor): the packed LU factorization data\n        LU_pivots (Tensor): the packed LU factorization pivots\n        unpack_data (bool): flag indicating if the data should be unpacked\n        unpack_pivots (bool): flag indicating if the pivots should be unpacked\n\n    Examples::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = A.lu()\n     
   >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>>\n        >>> # can recover A from factorization\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n\n        >>> # LU factorization of a rectangular matrix:\n        >>> A = torch.randn(2, 3, 2)\n        >>> A_LU, pivots = A.lu()\n        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)\n        >>> P\n        tensor([[[1., 0., 0.],\n                 [0., 1., 0.],\n                 [0., 0., 1.]],\n\n                [[0., 0., 1.],\n                 [0., 1., 0.],\n                 [1., 0., 0.]]])\n        >>> A_L\n        tensor([[[ 1.0000,  0.0000],\n                 [ 0.4763,  1.0000],\n                 [ 0.3683,  0.1135]],\n\n                [[ 1.0000,  0.0000],\n                 [ 0.2957,  1.0000],\n                 [-0.9668, -0.3335]]])\n        >>> A_U\n        tensor([[[ 2.1962,  1.0881],\n                 [ 0.0000, -0.8681]],\n\n                [[-1.0947,  0.3736],\n                 [ 0.0000,  0.5718]]])\n        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))\n        >>> torch.norm(A_ - A)\n        tensor(2.9802e-08)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        tens_ops = (LU_data, LU_pivots)\n        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):\n            return handle_torch_function(\n                lu_unpack, tens_ops, LU_data, LU_pivots, unpack_data=unpack_data,\n                unpack_pivots=unpack_pivots)\n    shape = LU_data.shape\n    # In generalized LU factorization, the following shape relations hold:\n    #   A.shape[-2:] == (m, n)\n    #   P.shape[-2:] == (m, m)\n    #   L.shape[-2:] == (m, k)\n    #   U.shape[-2:] == (k, n)\n    # where k = min(m, n)\n    m, n = shape[-2:]\n    k = min(m, n)\n    if unpack_data:\n        U: Optional[Tensor] = LU_data.triu()\n        assert U is not None\n        if m != k:\n            U = U.narrow(-2, 0, k)\n        L: Optional[Tensor] = LU_data.tril()\n        assert L is not None\n        if k 
!= n:\n            L = L.narrow(-1, 0, k)\n        L.diagonal(dim1=-2, dim2=-1).fill_(1)\n    else:\n        L = U = None\n\n    if unpack_pivots:\n        LU_pivots_zero_idx = LU_pivots - 1\n        if LU_data.dim() > 2:\n            P: Optional[Tensor] = torch.eye(m, device=LU_data.device,\n                                            dtype=LU_data.dtype) \\\n                .expand(shape[:-1] + (m,)) \\\n                .clone(memory_format=torch.contiguous_format)\n            assert P is not None\n\n            # TODO: rewrite when TorchScript supports product and map as\n            # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed\n            indices = _indices_product(shape[:-2])\n            for idx in indices:\n                final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n                for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):\n                    final_order[k], final_order[j] = final_order[j], final_order[k]\n                # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list\n                p_idx = _index_tensor_with_indices_list(P, idx)\n                p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))\n        else:\n            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)\n            final_order = [i for i in range(m)]  # noqa: C416 TODO: rewrite as list(range(m))\n            for k, j, in enumerate(LU_pivots_zero_idx):\n                final_order[k], final_order[j] = final_order[j], final_order[k]\n            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))\n    else:\n        P = None\n\n    return P, L, U\n\n\ndef einsum(equation, *operands):\n    r\"\"\"einsum(equation, *operands) -> Tensor\n\nThis function provides a way of computing multilinear expressions (i.e. 
sums of products) using the\nEinstein summation convention.\n\nArgs:\n    equation (string): The equation is given in terms of lower case letters (indices) to be associated\n           with each dimension of the operands and result. The left hand side lists the operands\n           dimensions, separated by commas. There should be one index letter per tensor dimension.\n           The right hand side follows after `->` and gives the indices for the output.\n           If the `->` and right hand side are omitted, it implicitly defined as the alphabetically\n           sorted list of all indices appearing exactly once in the left hand side.\n           The indices not apprearing in the output are summed over after multiplying the operands\n           entries.\n           If an index appears several times for the same operand, a diagonal is taken.\n           Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,\n           the ellipsis dimensions are at the beginning of the output.\n    operands (Tensor): The operands to compute the Einstein sum of.\n\n.. note::\n\n    This function does not optimize the given expression, so a different formula for the same computation may\n    run faster or consume less memory. 
Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)\n    can optimize the formula for you.\n\nExamples::\n\n    >>> x = torch.randn(5)\n    >>> y = torch.randn(4)\n    >>> torch.einsum('i,j->ij', x, y)  # outer product\n    tensor([[-0.0570, -0.0286, -0.0231,  0.0197],\n            [ 1.2616,  0.6335,  0.5113, -0.4351],\n            [ 1.4452,  0.7257,  0.5857, -0.4984],\n            [-0.4647, -0.2333, -0.1883,  0.1603],\n            [-1.1130, -0.5588, -0.4510,  0.3838]])\n\n\n    >>> A = torch.randn(3,5,4)\n    >>> l = torch.randn(2,5)\n    >>> r = torch.randn(2,4)\n    >>> torch.einsum('bn,anm,bm->ba', l, A, r) # compare torch.nn.functional.bilinear\n    tensor([[-0.3430, -5.2405,  0.4494],\n            [ 0.3311,  5.5201, -3.0356]])\n\n\n    >>> As = torch.randn(3,2,5)\n    >>> Bs = torch.randn(3,5,4)\n    >>> torch.einsum('bij,bjk->bik', As, Bs) # batch matrix multiplication\n    tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],\n             [-1.6706, -0.8097, -0.8025, -2.1183]],\n\n            [[ 4.2239,  0.3107, -0.5756, -0.2354],\n             [-1.4558, -0.3460,  1.5087, -0.8530]],\n\n            [[ 2.8153,  1.8787, -4.3839, -1.2112],\n             [ 0.3728, -2.1131,  0.0921,  0.8305]]])\n\n    >>> A = torch.randn(3, 3)\n    >>> torch.einsum('ii->i', A) # diagonal\n    tensor([-0.7825,  0.8291, -0.1936])\n\n    >>> A = torch.randn(4, 3, 3)\n    >>> torch.einsum('...ii->...i', A) # batch diagonal\n    tensor([[-1.0864,  0.7292,  0.0569],\n            [-0.9725, -1.0270,  0.6493],\n            [ 0.5832, -1.1716, -1.5084],\n            [ 0.4041, -1.1690,  0.8570]])\n\n    >>> A = torch.randn(2, 3, 4, 5)\n    >>> torch.einsum('...ij->...ji', A).shape # batch permute\n    torch.Size([2, 3, 5, 4])\n\"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in operands) and has_torch_function(operands):\n            return handle_torch_function(einsum, operands, equation, *operands)\n    if len(operands) == 1 and 
isinstance(operands[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        _operands = operands[0]\n        # recurse incase operands contains value that has torch function\n        # in the original implementation this line is omitted\n        return einsum(equation, *_operands)\n\n    return _VF.einsum(equation, operands)  # type: ignore\n\n\nif TYPE_CHECKING:\n    # The JIT doesn't understand Union, so only add type annotation for mypy\n    def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:\n        return _meshgrid(*tensors)\nelse:\n    def meshgrid(*tensors):\n        return _meshgrid(*tensors)\n\n\ndef _meshgrid(*tensors):\n    r\"\"\"Take :math:`N` tensors, each of which can be either scalar or 1-dimensional\nvector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by\nexpanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.\n\n\n    Args:\n        tensors (list of Tensor): list of scalars or 1 dimensional tensors. 
Scalars will be\n        treated as tensors of size :math:`(1,)` automatically\n\n    Returns:\n        seq (sequence of Tensors): If the input has :math:`k` tensors of size\n        :math:`(N_1,), (N_2,), \\ldots , (N_k,)`, then the output would also have :math:`k` tensors,\n        where all tensors are of size :math:`(N_1, N_2, \\ldots , N_k)`.\n\n    Example::\n\n        >>> x = torch.tensor([1, 2, 3])\n        >>> y = torch.tensor([4, 5, 6])\n        >>> grid_x, grid_y = torch.meshgrid(x, y)\n        >>> grid_x\n        tensor([[1, 1, 1],\n                [2, 2, 2],\n                [3, 3, 3]])\n        >>> grid_y\n        tensor([[4, 5, 6],\n                [4, 5, 6],\n                [4, 5, 6]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(meshgrid, tensors, *tensors)\n    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):\n        # the old interface of passing the operands as one list argument\n        tensors = tensors[0]  # type: ignore\n    return _VF.meshgrid(tensors)  # type: ignore\n\n\ndef stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,\n         win_length: Optional[int] = None, window: Optional[Tensor] = None,\n         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,\n         onesided: Optional[bool] = None,\n         return_complex: Optional[bool] = None) -> Tensor:\n    r\"\"\"Short-time Fourier transform (STFT).\n\n    .. warning::\n        Setting :attr:`return_complex` explicitly will be required in a future\n        PyTorch release. Set it to False to preserve the current behavior or\n        True to return a complex output.\n\n    The STFT computes the Fourier transform of short overlapping windows of the\n    input. This giving frequency components of the signal as they change over\n    time. 
The interface of this function is modeled after the librosa_ stft function.\n\n    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html\n\n    Ignoring the optional batch dimension, this method computes the following\n    expression:\n\n    .. math::\n        X[m, \\omega] = \\sum_{k = 0}^{\\text{win\\_length-1}}%\n                            \\text{window}[k]\\ \\text{input}[m \\times \\text{hop\\_length} + k]\\ %\n                            \\exp\\left(- j \\frac{2 \\pi \\cdot \\omega k}{\\text{win\\_length}}\\right),\n\n    where :math:`m` is the index of the sliding window, and :math:`\\omega` is\n    the frequency that :math:`0 \\leq \\omega < \\text{n\\_fft}`. When\n    :attr:`onesided` is the default value ``True``,\n\n    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time\n      sequences.\n\n    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to\n      ``floor(n_fft / 4)``.\n\n    * If :attr:`win_length` is ``None`` (default), it is treated as equal to\n      :attr:`n_fft`.\n\n    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from\n      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is\n      treated as if having :math:`1` everywhere in the window. If\n      :math:`\\text{win\\_length} < \\text{n\\_fft}`, :attr:`window` will be padded on\n      both sides to length :attr:`n_fft` before being applied.\n\n    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on\n      both sides so that the :math:`t`-th frame is centered at time\n      :math:`t \\times \\text{hop\\_length}`. Otherwise, the :math:`t`-th frame\n      begins at time  :math:`t \\times \\text{hop\\_length}`.\n\n    * :attr:`pad_mode` determines the padding method used on :attr:`input` when\n      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for\n      all available options. 
Default is ``\"reflect\"``.\n\n    * If :attr:`onesided` is ``True`` (default for real input), only values for\n      :math:`\\omega` in :math:`\\left[0, 1, 2, \\dots, \\left\\lfloor\n      \\frac{\\text{n\\_fft}}{2} \\right\\rfloor + 1\\right]` are returned because\n      the real-to-complex Fourier transform satisfies the conjugate symmetry,\n      i.e., :math:`X[m, \\omega] = X[m, \\text{n\\_fft} - \\omega]^*`.\n      Note if the input or window tensors are complex, then :attr:`onesided`\n      output is not possible.\n\n    * If :attr:`normalized` is ``True`` (default is ``False``), the function\n      returns the normalized STFT results, i.e., multiplied by :math:`(\\text{frame\\_length})^{-0.5}`.\n\n    * If :attr:`return_complex` is ``True`` (default if input is complex), the\n      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,\n      the output is a ``input.dim() + 2`` dimensional real tensor where the last\n      dimension represents the real and imaginary components.\n\n    Returns either a complex tensor of size :math:`(* \\times N \\times T)` if\n    :attr:`return_complex` is true, or a real tensor of size :math:`(* \\times N\n    \\times T \\times 2)`. Where :math:`*` is the optional batch size of\n    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied\n    and :math:`T` is the total number of frames used.\n\n    .. warning::\n      This function changed signature at version 0.4.1. Calling with the\n      previous signature may cause error or return incorrect result.\n\n    Arguments:\n        input (Tensor): the input tensor\n        n_fft (int): size of Fourier transform\n        hop_length (int, optional): the distance between neighboring sliding window\n            frames. 
Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)\n        win_length (int, optional): the size of window frame and STFT filter.\n            Default: ``None``  (treated as equal to :attr:`n_fft`)\n        window (Tensor, optional): the optional window function.\n            Default: ``None`` (treated as window of all :math:`1` s)\n        center (bool, optional): whether to pad :attr:`input` on both sides so\n            that the :math:`t`-th frame is centered at time :math:`t \\times \\text{hop\\_length}`.\n            Default: ``True``\n        pad_mode (string, optional): controls the padding method used when\n            :attr:`center` is ``True``. Default: ``\"reflect\"``\n        normalized (bool, optional): controls whether to return the normalized STFT results\n             Default: ``False``\n        onesided (bool, optional): controls whether to return half of results to\n            avoid redundancy for real inputs.\n            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.\n        return_complex (bool, optional): whether to return a complex tensor, or\n            a real tensor with an extra last dimension for the real and\n            imaginary components.\n\n    Returns:\n        Tensor: A tensor containing the STFT result with shape described above\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, pad_mode=pad_mode, normalized=normalized,\n                onesided=onesided, return_complex=return_complex)\n    # TODO: after having proper ways to map Python strings to ATen Enum, move\n    #       this and F.pad to ATen.\n    if center:\n        signal_dim = input.dim()\n        extended_shape = [1] * (3 - signal_dim) + list(input.size())\n        pad = 
int(n_fft // 2)\n        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)\n        input = input.view(input.shape[-signal_dim:])\n        \n    # STFT patch for aarch64\n    # https://stackoverflow.com/a/66872148\n    librosa_stft = librosa.stft(input.cpu().detach().numpy().reshape(-1), n_fft, hop_length, win_length, window=\"hann\", center=center, pad_mode=pad_mode)\n    librosa_stft = np.array([[a.real, a.imag] for a in librosa_stft])\n    librosa_stft = np.transpose(librosa_stft, axes=[0, 2, 1])\n    librosa_stft = np.expand_dims(librosa_stft, 0)\n    librosa_stft = torch.from_numpy(librosa_stft)\n    return librosa_stft\n    #return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore\n    #                normalized, onesided, return_complex)\n\ndef istft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,\n          win_length: Optional[int] = None, window: Optional[Tensor] = None,\n          center: bool = True, normalized: bool = False,\n          onesided: Optional[bool] = None, length: Optional[int] = None,\n          return_complex: bool = False) -> Tensor:\n    r\"\"\"Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.\n    It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the\n    least squares estimation of the original signal. The algorithm will check using the NOLA condition (\n    nonzero overlap).\n\n    Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop\n    created by the summation of all the windows is never zero at certain point in time. 
Specifically,\n    :math:`\\sum_{t=-\\infty}^{\\infty} |w|^2[n-t\\times hop\\_length] \\cancel{=} 0`.\n\n    Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,\n    ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False\n    since the signal isn't padded).\n\n    If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.\n    Left padding can be trimmed off exactly because they can be calculated but right padding cannot be\n    calculated without additional information.\n\n    Example: Suppose the last window is:\n    ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``\n\n    The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation\n    of right padding. These additional values could be zeros or a reflection of the signal so providing\n    :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed\n    (some loss of signal).\n\n    [1] D. W. Griffin and J. S. Lim, \"Signal estimation from modified short-time Fourier transform,\"\n    IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.\n\n    Arguments:\n        input (Tensor): The input tensor. Expected to be output of :func:`~torch.stft`,\n            can either be complex (``channel``, ``fft_size``, ``n_frame``), or real\n            (``channel``, ``fft_size``, ``n_frame``, 2) where the ``channel``\n            dimension is optional.\n        n_fft (int): Size of Fourier transform\n        hop_length (Optional[int]): The distance between neighboring sliding window frames.\n            (Default: ``n_fft // 4``)\n        win_length (Optional[int]): The size of window frame and STFT filter. 
(Default: ``n_fft``)\n        window (Optional[torch.Tensor]): The optional window function.\n            (Default: ``torch.ones(win_length)``)\n        center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is\n            centered at time :math:`t \\times \\text{hop\\_length}`.\n            (Default: ``True``)\n        normalized (bool): Whether the STFT was normalized. (Default: ``False``)\n        onesided (Optional[bool]): Whether the STFT was onesided.\n            (Default: ``True`` if ``n_fft != fft_size`` in the input size)\n        length (Optional[int]): The amount to trim the signal by (i.e. the\n            original signal length). (Default: whole signal)\n        return_complex (Optional[bool]):\n            Whether the output should be complex, or if the input should be\n            assumed to derive from a real signal and window.\n            Note that this is incompatible with ``onesided=True``.\n            (Default: ``False``)\n\n    Returns:\n        Tensor: Least squares estimation of the original signal of size (..., signal_length)\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                istft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,\n                window=window, center=center, normalized=normalized, onesided=onesided,\n                length=length, return_complex=return_complex)\n\n    return _VF.istft(input, n_fft, hop_length, win_length, window, center,  # type: ignore\n                     normalized, onesided, length, return_complex)\n\n\ndel torch.unique_dim\n\n\nif TYPE_CHECKING:\n    # These _impl functions return a variable number of tensors as output with\n    # __torch_function__; tuple unpacking is done already rather than being\n    # done by the caller of the _impl function\n    _unique_impl_out = Any\nelse:\n    _unique_impl_out = 
Tuple[Tensor, Tensor, Tensor]\n\n\ndef _unique_impl(input: Tensor, sorted: bool = True,\n                 return_inverse: bool = False, return_counts: bool = False,\n                 dim: Optional[int] = None) -> _unique_impl_out:\n    r\"\"\"Returns the unique elements of the input tensor.\n\n    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that\n        this function also eliminates non-consecutive duplicate values.\n\n    .. note:: Currently in the CUDA implementation and the CPU implementation when dim is specified,\n        `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument.\n        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use\n        :func:`torch.unique_consecutive` which avoids the sorting.\n\n    Arguments:\n        input (Tensor): the input tensor\n        sorted (bool): Whether to sort the unique elements in ascending order\n            before returning as output.\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))\n        >>> output\n        tensor([ 2,  3,  1])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([ 0,  2,  1,  2])\n\n        >>> output, inverse_indices = torch.unique(\n                torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)\n        >>> output\n        tensor([ 1,  2,  3])\n        >>> inverse_indices\n        tensor([[ 0,  2],\n                [ 1,  2]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique, (input,), input, sorted=sorted, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n\n    if dim is not None:\n        output, inverse_indices, counts = 
_VF.unique_dim(  # type: ignore\n            input,\n            dim,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    else:\n        output, inverse_indices, counts = torch._unique2(\n            input,\n            sorted=sorted,\n            return_inverse=return_inverse,\n            return_counts=return_counts,\n        )\n    return output, inverse_indices, counts\n\n\ndef _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,\n                             return_counts: bool = False,\n                             dim: Optional[int] = None) -> _unique_impl_out:\n    r\"\"\"Eliminates all but the first element from every consecutive group of equivalent elements.\n\n    .. note:: This function is different from :func:`torch.unique` in the sense that this function\n        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`\n        in C++.\n\n    Arguments:\n        input (Tensor): the input tensor\n        return_inverse (bool): Whether to also return the indices for where\n            elements in the original input ended up in the returned unique list.\n        return_counts (bool): Whether to also return the counts for each unique\n            element.\n        dim (int): the dimension to apply unique. If ``None``, the unique of the\n            flattened input is returned. 
default: ``None``\n\n    Returns:\n        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing\n\n            - **output** (*Tensor*): the output list of unique scalar elements.\n            - **inverse_indices** (*Tensor*): (optional) if\n              :attr:`return_inverse` is True, there will be an additional\n              returned tensor (same shape as input) representing the indices\n              for where elements in the original input map to in the output;\n              otherwise, this function will only return a single tensor.\n            - **counts** (*Tensor*): (optional) if\n              :attr:`return_counts` is True, there will be an additional\n              returned tensor (same shape as output or output.size(dim),\n              if dim was specified) representing the number of occurrences\n              for each unique value or tensor.\n\n    Example::\n\n        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])\n        >>> output = torch.unique_consecutive(x)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n\n        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> inverse_indices\n        tensor([0, 0, 1, 1, 2, 3, 3, 4])\n\n        >>> output, counts = torch.unique_consecutive(x, return_counts=True)\n        >>> output\n        tensor([1, 2, 3, 1, 2])\n        >>> counts\n        tensor([2, 2, 1, 2, 1])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                unique_consecutive, (input,), input, return_inverse=return_inverse,\n                return_counts=return_counts, dim=dim)\n    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore\n        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)\n    return output, inverse_indices, 
counts\n\n\ndef _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, counts\n\n\ndef _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output\n\n\ndef _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_impl(input, sorted, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n\n_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_return_counts,\n    if_false=_return_output,\n    module_name=__name__,\n    func_name='unique')\n\n_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=3,\n    default=False,\n    if_true=_unique_impl,\n    if_false=_return_inverse,\n    module_name=__name__,\n    func_name='unique')\n\n# The return type of unique depends on `return_inverse`, and 
`return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_return_inverse_true,\n    if_false=_return_inverse_false,\n    module_name=__name__,\n    func_name='unique')\nunique.__doc__ = _unique_impl.__doc__\n\n\ndef _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, counts\n\n\ndef _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tensor\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output\n\n\ndef _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):\n    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n\n    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)\n    return output, inverse_indices\n\n\n_consecutive_return_inverse_false = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    
if_true=_consecutive_return_counts,\n    if_false=_consecutive_return_output,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n_consecutive_return_inverse_true = boolean_dispatch(\n    arg_name='return_counts',\n    arg_index=1,\n    default=False,\n    if_true=_unique_consecutive_impl,\n    if_false=_consecutive_return_inverse,\n    module_name=__name__,\n    func_name='unique_consecutive')\n\n# The return type of unique depends on `return_inverse`, and `return_counts` so in order to\n# resolve the output type in TorchScript we need to statically know the value of both parameters\n\nunique_consecutive = boolean_dispatch(\n    arg_name='return_inverse',\n    arg_index=2,\n    default=False,\n    if_true=_consecutive_return_inverse_true,\n    if_false=_consecutive_return_inverse_false,\n    module_name=__name__,\n    func_name='unique_consecutive')\nunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__\n\n\ndef tensordot(a, b, dims=2):\n    r\"\"\"Returns a contraction of a and b over multiple dimensions.\n\n    :attr:`tensordot` implements a generalized matrix product.\n\n    Args:\n      a (Tensor): Left tensor to contract\n      b (Tensor): Right tensor to contract\n      dims (int or tuple of two lists of integers): number of dimensions to\n         contract or explicit lists of dimensions for :attr:`a` and\n         :attr:`b` respectively\n\n    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and\n    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,\n    respectively, :func:`~torch.tensordot` computes\n\n    .. math::\n        r_{i_0,...,i_{m-d}, i_d,...,i_n}\n          = \\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}.\n\n    When called with :attr:`dims` of the list form, the given dimensions will be contracted\n    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. 
The sizes\n    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted\n    dimensions.\n\n    Examples::\n\n        >>> a = torch.arange(60.).reshape(3, 4, 5)\n        >>> b = torch.arange(24.).reshape(4, 3, 2)\n        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))\n        tensor([[4400., 4730.],\n                [4532., 4874.],\n                [4664., 5018.],\n                [4796., 5162.],\n                [4928., 5306.]])\n\n        >>> a = torch.randn(3, 4, 5, device='cuda')\n        >>> b = torch.randn(4, 5, 6, device='cuda')\n        >>> c = torch.tensordot(a, b, dims=2).cpu()\n        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],\n                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],\n                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])\n\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)):\n            return handle_torch_function(tensordot, (a, b), a, b, dims=dims)\n    if isinstance(dims, (list, tuple)) or \\\n       (isinstance(dims, torch.Tensor) and dims.numel() > 1):\n        dims_a, dims_b = dims\n    else:\n        if isinstance(dims, torch.Tensor):\n            dims = dims.item()\n        if dims < 0:\n            raise RuntimeError(f\"tensordot expects dims >= 0, but got dims={dims}\")\n        dims_a = list(range(-dims, 0))\n        dims_b = list(range(dims))\n    return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore\n\ndef cartesian_prod(*tensors):\n    \"\"\"Do cartesian product of the given sequence of tensors. 
The behavior is similar to\n    python's `itertools.product`.\n\n    Arguments:\n        *tensors: any number of 1 dimensional tensors.\n\n    Returns:\n        Tensor: A tensor equivalent to converting all the input tensors into lists,\n            do `itertools.product` on these lists, and finally convert the resulting list\n            into tensor.\n\n    Example::\n\n        >>> a = [1, 2, 3]\n        >>> b = [4, 5]\n        >>> list(itertools.product(a, b))\n        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]\n        >>> tensor_a = torch.tensor(a)\n        >>> tensor_b = torch.tensor(b)\n        >>> torch.cartesian_prod(tensor_a, tensor_b)\n        tensor([[1, 4],\n                [1, 5],\n                [2, 4],\n                [2, 5],\n                [3, 4],\n                [3, 5]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(cartesian_prod, tensors, *tensors)\n    return _VF.cartesian_prod(tensors)  # type: ignore\n\ndef block_diag(*tensors):\n    \"\"\"Create a block diagonal matrix from provided tensors.\n\n    Arguments:\n        *tensors: One or more tensors with 0, 1, or 2 dimensions.\n\n    Returns:\n        Tensor: A 2 dimensional tensor with all the input tensors arranged in\n            order such that their upper left and lower right corners are\n            diagonally adjacent. 
All other elements are set to 0.\n\n    Example::\n\n        >>> import torch\n        >>> A = torch.tensor([[0, 1], [1, 0]])\n        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])\n        >>> C = torch.tensor(7)\n        >>> D = torch.tensor([1, 2, 3])\n        >>> E = torch.tensor([[4], [5], [6]])\n        >>> torch.block_diag(A, B, C, D, E)\n        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],\n                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],\n                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],\n                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])\n    \"\"\"\n    if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n        return handle_torch_function(block_diag, tensors, *tensors)\n    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore\n\n\ndef cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):\n    # type: (Tensor, Tensor, float, str) -> (Tensor)\n    r\"\"\"Computes batched the p-norm distance between each pair of the two collections of row vectors.\n\n    Args:\n        x1 (Tensor): input tensor of shape :math:`B \\times P \\times M`.\n        x2 (Tensor): input tensor of shape :math:`B \\times R \\times M`.\n        p: p value for the p-norm distance to calculate between each vector pair\n            :math:`\\in [0, \\infty]`.\n        compute_mode:\n            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate\n            euclidean distance (p = 2) if P > 25 or R > 25\n            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate\n            euclidean distance (p = 2)\n            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate\n        
    euclidean distance (p = 2)\n            Default: use_mm_for_euclid_dist_if_necessary.\n\n    If x1 has shape :math:`B \\times P \\times M` and x2 has shape :math:`B \\times R \\times M` then the\n    output will have shape :math:`B \\times P \\times R`.\n\n    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`\n    if :math:`p \\in (0, \\infty)`. When :math:`p = 0` it is equivalent to\n    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \\infty`, the closest\n    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.\n\n    Example:\n\n        >>> a = torch.tensor([[0.9041,  0.0196], [-0.3108, -2.4423], [-0.4821,  1.059]])\n        >>> a\n        tensor([[ 0.9041,  0.0196],\n                [-0.3108, -2.4423],\n                [-0.4821,  1.0590]])\n        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986,  1.3702]])\n        >>> b\n        tensor([[-2.1763, -0.4713],\n                [-0.6986,  1.3702]])\n        >>> torch.cdist(a, b, p=2)\n        tensor([[3.1193, 2.0959],\n                [2.7138, 3.8322],\n                [2.2830, 0.3791]])\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)):\n            return handle_torch_function(\n                cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)\n    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':\n        return _VF.cdist(x1, x2, p, None)  # type: ignore\n    elif compute_mode == 'use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 1)  # type: ignore\n    elif compute_mode == 'donot_use_mm_for_euclid_dist':\n        return _VF.cdist(x1, x2, p, 2)  # type: ignore\n    else:\n        raise ValueError(f\"{compute_mode} is not a valid value for compute_mode\")\n\ndef atleast_1d(*tensors):\n    r\"\"\"\n    Returns a 1-dimensional view of each input tensor with zero dimensions.\n    Input tensors 
with one or more dimensions are returned as-is.\n\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example::\n        >>> x = torch.randn(2)\n        >>> x\n        tensor([1.4584, 0.7583])\n        >>> torch.atleast_1d(x)\n        tensor([1.4584, 0.7583])\n        >>> x = torch.tensor(1.)\n        >>> x\n        tensor(1.)\n        >>> torch.atleast_1d(x)\n        tensor([1.])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_1d((x,y))\n        (tensor([0.5000]), tensor([1.]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_1d, tensors, *tensors)\n    if len(tensors) == 1:\n        tensors = tensors[0]\n    return _VF.atleast_1d(tensors)  # type: ignore\n\ndef atleast_2d(*tensors):\n    r\"\"\"\n    Returns a 2-dimensional view of each each input tensor with zero dimensions.\n    Input tensors with two or more dimensions are returned as-is.\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example::\n        >>> x = torch.tensor(1.)\n        >>> x\n        tensor(1.)\n        >>> torch.atleast_2d(x)\n        tensor([[1.]])\n        >>> x = torch.randn(2,2)\n        >>> x\n        tensor([[2.2086, 2.5165],\n                [0.1757, 0.5194]])\n        >>> torch.atleast_2d(x)\n        tensor([[2.2086, 2.5165],\n                [0.1757, 0.5194]])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_2d((x,y))\n        (tensor([[0.5000]]), tensor([[1.]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_2d, tensors, *tensors)\n    if len(tensors) == 1:\n        
tensors = tensors[0]\n    return _VF.atleast_2d(tensors)  # type: ignore\n\ndef atleast_3d(*tensors):\n    r\"\"\"\n    Returns a 3-dimensional view of each each input tensor with zero dimensions.\n    Input tensors with three or more dimensions are returned as-is.\n    Args:\n        input (Tensor or list of Tensors)\n\n    Returns:\n        output (Tensor or tuple of Tensors)\n\n    Example:\n\n        >>> x = torch.tensor(0.5)\n        >>> x\n        tensor(0.5000)\n        >>> torch.atleast_3d(x)\n        tensor([[[0.5000]]])\n        >>> y = torch.randn(2,2)\n        >>> y\n        tensor([[-0.8079,  0.7460],\n                [-1.1647,  1.4734]])\n        >>> torch.atleast_3d(y)\n        tensor([[[-0.8079],\n                [ 0.7460]],\n                <BLANKLINE>\n                [[-1.1647],\n                [ 1.4734]]])\n        >>> x = torch.randn(1,1,1)\n        >>> x\n        tensor([[[-1.5689]]])\n        >>> torch.atleast_3d(x)\n        tensor([[[-1.5689]]])\n        >>> x = torch.tensor(0.5)\n        >>> y = torch.tensor(1.)\n        >>> torch.atleast_3d((x,y))\n        (tensor([[[0.5000]]]), tensor([[[1.]]]))\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):\n            return handle_torch_function(atleast_3d, tensors, *tensors)\n    if len(tensors) == 1:\n        tensors = tensors[0]\n    return _VF.atleast_3d(tensors)  # type: ignore\n\n\nif TYPE_CHECKING:\n    pass\n    # There's no good way to use this type annotation; cannot rename norm() to\n    # _norm_impl() in a way that doesn't break JIT overloads. 
So leave untyped\n    # for mypy for now.\n    #    def norm(input: Tensor,\n    #             p: Optional[Union[str, Number]] = \"fro\",\n    #             dim: Optional[Union[int, List[int]]] = None,\n    #             keepdim: bool = False,\n    #             out: Optional[Tensor] = None,\n    #             dtype: _dtype = None) -> Tensor:\n    #        return _norm_impl(input, p, dim, keepdim, out, dtype)\nelse:\n    # TODO: type dim as BroadcastingList when\n    # https://github.com/pytorch/pytorch/issues/33782 is fixed\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n    @overload  # noqa: 749\n    def norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor\n        pass\n\n\ndef norm(input, p=\"fro\", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749\n    r\"\"\"Returns the matrix norm or vector norm of a given tensor.\n\n    .. 
warning::\n\n        torch.norm is deprecated and may be removed in a future PyTorch release.\n        Use :func:`torch.linalg.norm` instead, but note that :func:`torch.linalg.norm`\n        has a different signature and slightly different behavior that is\n        more consistent with NumPy's numpy.linalg.norm.\n\n    Args:\n        input (Tensor): the input tensor\n        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``\n            The following norms can be calculated:\n\n            =====  ============================  ==========================\n            ord    matrix norm                   vector norm\n            =====  ============================  ==========================\n            None   Frobenius norm                2-norm\n            'fro'  Frobenius norm                --\n            'nuc'  nuclear norm                  --\n            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)\n            =====  ============================  ==========================\n\n        dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,\n            vector norm will be calculated, if it is 2-tuple of ints, matrix norm\n            will be calculated. If the value is None, matrix norm will be calculated\n            when the input tensor only has two dimensions, vector norm will be\n            calculated when the input tensor only has one dimension. If the input\n            tensor has more than two dimensions, the vector norm will be applied to\n            last dimension.\n        keepdim (bool, optional): whether the output tensors have :attr:`dim`\n            retained or not. Ignored if :attr:`dim` = ``None`` and\n            :attr:`out` = ``None``. Default: ``False``\n        out (Tensor, optional): the output tensor. 
Ignored if\n            :attr:`dim` = ``None`` and :attr:`out` = ``None``.\n        dtype (:class:`torch.dtype`, optional): the desired data type of\n            returned tensor. If specified, the input tensor is casted to\n            :attr:'dtype' while performing the operation. Default: None.\n\n\n    Example::\n\n        >>> import torch\n        >>> a = torch.arange(9, dtype= torch.float) - 4\n        >>> b = a.reshape((3, 3))\n        >>> torch.norm(a)\n        tensor(7.7460)\n        >>> torch.norm(b)\n        tensor(7.7460)\n        >>> torch.norm(a, float('inf'))\n        tensor(4.)\n        >>> torch.norm(b, float('inf'))\n        tensor(4.)\n        >>> c = torch.tensor([[ 1, 2, 3],[-1, 1, 4]] , dtype= torch.float)\n        >>> torch.norm(c, dim=0)\n        tensor([1.4142, 2.2361, 5.0000])\n        >>> torch.norm(c, dim=1)\n        tensor([3.7417, 4.2426])\n        >>> torch.norm(c, p=1, dim=1)\n        tensor([6., 6.])\n        >>> d = torch.arange(8, dtype= torch.float).reshape(2,2,2)\n        >>> torch.norm(d, dim=(1,2))\n        tensor([ 3.7417, 11.2250])\n        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])\n        (tensor(3.7417), tensor(11.2250))\n    \"\"\"\n\n    if not torch.jit.is_scripting():\n        if type(input) is not Tensor and has_torch_function((input,)):\n            return handle_torch_function(\n                norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)\n\n    ndim = input.dim()\n\n    # catch default case\n    if dim is None and out is None and dtype is None and p is not None:\n        if isinstance(p, str):\n            if p == \"fro\":\n                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)  # type: ignore\n        if not isinstance(p, str):\n            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore\n\n    # TODO: when 
https://github.com/pytorch/pytorch/issues/33782 is fixed\n    # remove the overloads where dim is an int and replace with BraodcastingList1\n    # and remove next four lines, replace _dim with dim\n    if dim is not None:\n        if isinstance(dim, int):\n            _dim = [dim]\n        else:\n            _dim = dim\n    else:\n        _dim = None  # type: ignore\n\n    if isinstance(p, str):\n        if p == \"fro\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in frobenius norm\")\n\n            if _dim is None:\n                _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n            if out is None:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore\n            else:\n                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore\n        elif p == \"nuc\":\n            if dtype is not None:\n                raise ValueError(\"dtype argument is not supported in nuclear norm\")\n            if _dim is None:\n                if out is None:\n                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore\n                else:\n                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore\n            else:\n                if out is None:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore\n                else:\n                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore\n        raise RuntimeError(f\"only valid string values are 'fro' and 'nuc', found {p}\")\n    else:\n        if _dim is None:\n            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))\n\n        if out is None:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore\n            else:\n                return 
_VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore\n        else:\n            if dtype is None:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore\n            else:\n                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore\n\ndef chain_matmul(*matrices):\n    r\"\"\"Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed\n    using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms\n    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`\n    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.\n    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.\n\n\n    Args:\n        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.\n\n\n    Returns:\n        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \\times p_{i + 1}`, then the product\n        would be of dimensions :math:`p_{1} \\times p_{N + 1}`.\n\n    Example::\n\n        >>> a = torch.randn(3, 4)\n        >>> b = torch.randn(4, 5)\n        >>> c = torch.randn(5, 6)\n        >>> d = torch.randn(6, 7)\n        >>> torch.chain_matmul(a, b, c, d)\n        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],\n                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],\n                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])\n\n    .. 
_`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition\n    \"\"\"\n    if not torch.jit.is_scripting():\n        if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):\n            return handle_torch_function(chain_matmul, matrices, *matrices)\n    return _VF.chain_matmul(matrices)  # type: ignore\n\n\ndef _lu_impl(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]\n    r\"\"\"Computes the LU factorization of a matrix or batches of matrices\n    :attr:`A`. Returns a tuple containing the LU factorization and\n    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to\n    ``True``.\n\n    .. note::\n        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,\n        then the returned pivots is a tensor filled with zeros of the appropriate size.\n\n    .. note::\n        LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting\n        to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is\n        available for CUDA.\n\n    .. note::\n        This function does not check if the factorization was successful or not if\n        :attr:`get_infos` is ``True`` since the status of the factorization is present in the\n        third element of the return tuple.\n\n    .. note::\n        In the case of batches of square matrices with size less or\n        equal to 32 on a CUDA device, the LU factorization is repeated\n        for singular matrices due to the bug in the MAGMA library (see\n        magma issue 13).\n\n    .. note::\n       ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.\n\n    Arguments:\n        A (Tensor): the tensor to factor of size :math:`(*, m, n)`\n        pivot (bool, optional): controls whether pivoting is done. 
Default: ``True``\n        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.\n                                    Default: ``False``\n        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,\n                               then the elements in the tuple are Tensor, IntTensor,\n                               and IntTensor. If :attr:`get_infos` is ``False``, then the\n                               elements in the tuple are Tensor, IntTensor. Default: ``None``\n\n    Returns:\n        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing\n\n            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`\n\n            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`\n\n            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of\n              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or\n              each minibatch has succeeded or failed\n\n    Example::\n\n        >>> A = torch.randn(2, 3, 3)\n        >>> A_LU, pivots = torch.lu(A)\n        >>> A_LU\n        tensor([[[ 1.3506,  2.5558, -0.0816],\n                 [ 0.1684,  1.1551,  0.1940],\n                 [ 0.1193,  0.6189, -0.5497]],\n\n                [[ 0.4526,  1.2526, -0.3285],\n                 [-0.7988,  0.7175, -0.9701],\n                 [ 0.2634, -0.9255, -0.3459]]])\n        >>> pivots\n        tensor([[ 3,  3,  3],\n                [ 3,  3,  3]], dtype=torch.int32)\n        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)\n        >>> if info.nonzero().size(0) == 0:\n        ...   
print('LU factorization succeeded for all samples!')\n        LU factorization succeeded for all samples!\n    \"\"\"\n    # If get_infos is True, then we don't need to check for errors and vice versa\n    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))\n\n\nif TYPE_CHECKING:\n    _ListOrSeq = Sequence[Tensor]\nelse:\n    _ListOrSeq = List[Tensor]\n\ndef _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:\n    get_infos_int = 1 if get_infos else 0\n    if out_len - get_infos_int != 2:\n        raise TypeError(f\"expected tuple of {2 + int(get_infos)} elements but got {out_len}\")\n    if not isinstance(out, (tuple, list)):\n        raise TypeError(f\"argument 'out' must be tuple of Tensors, not {type(out).__name__}\")\n\ndef _lu_with_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        _check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result  # A_LU, pivots, infos\n\ndef _lu_no_infos(A, pivot=True, get_infos=False, out=None):\n    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]\n    # need to check for torch_function here so that we exit if\n    if not torch.jit.is_scripting():\n        if type(A) is not Tensor and has_torch_function((A,)):\n            return handle_torch_function(\n                lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)\n    result = _lu_impl(A, pivot, get_infos, out)\n    if out is not None:\n        
_check_list_size(len(out), get_infos, out)\n        for i in range(len(out)):\n            out[i].resize_as_(result[i]).copy_(result[i])\n        return out\n    else:\n        return result[0], result[1]  # A_LU, pivots\n\n# The return type of lu depends on `get_infos`, so in order to resolve the output type\n# of lu in TorchScript we need to statically know the value of `get_infos`\nlu = boolean_dispatch(\n    arg_name='get_infos',\n    arg_index=2,\n    default=False,\n    if_true=_lu_with_infos,\n    if_false=_lu_no_infos,\n    module_name=__name__,\n    func_name='lu')\nlu.__doc__ = _lu_impl.__doc__\n\ndef align_tensors(*tensors):\n    raise RuntimeError('`align_tensors` not yet implemented.')\n"
  },
  {
    "path": "patches/transformers/4.5.0/convert_graph_to_onnx.diff",
    "content": "14a15,17\n> import os \n> import json\n> \n83a87,91\n>             \"--save-config\",\n>             action=\"store_true\",\n>             help=\"Save the model configuration along with the ONNX\",\n>         )\n>         self.add_argument(\n280a289,295\n>         print('Exporting from PyTorch to ONNX...')\n>         print('input_names', input_names)\n>         print('output_names', output_names)\n>         print('dynamic_axes', dynamic_axes)\n>         print('tokens', tokens)\n>         print('model_args', model_args)\n>         \n291a307\n>             verbose=True\n339a356\n>     save_config: bool = False,\n366,367c383,384\n<     elif len(listdir(output.parent.as_posix())) > 0:\n<         raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n---\n>     #elif len(listdir(output.parent.as_posix())) > 0:\n>     #    raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n374c391,407\n< \n---\n>         \n>     # Save the configuration\n>     if save_config:\n>         config_path = os.path.splitext(output)[0] + '.json'\n> \n>         config = dict(\n>             model = nlp.model.config.to_dict(),\n>             tokenizer = nlp.tokenizer.init_kwargs\n>         )\n>         \n>         #nlp.model.config.to_json_file(config_path)\n>         \n>         with open(config_path, 'w') as config_file:\n>             json.dump(config, config_file, indent=2)\n>             \n>         print(f\"Saved config to {config_path}\")\n>         \n468a502\n>             args.save_config\n"
  },
  {
    "path": "patches/transformers/4.5.0/convert_graph_to_onnx.original.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom argparse import ArgumentParser\nfrom os import listdir, makedirs\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple\n\nfrom packaging.version import Version, parse\n\nfrom transformers.file_utils import ModelOutput, is_tf_available, is_torch_available\nfrom transformers.pipelines import Pipeline, pipeline\nfrom transformers.tokenization_utils import BatchEncoding\n\n\n# This is the minimal required version to\n# support some ONNX Runtime features\nORT_QUANTIZE_MINIMUM_VERSION = parse(\"1.4.0\")\n\n\nSUPPORTED_PIPELINES = [\n    \"feature-extraction\",\n    \"ner\",\n    \"sentiment-analysis\",\n    \"fill-mask\",\n    \"question-answering\",\n    \"text-generation\",\n    \"translation_en_to_fr\",\n    \"translation_en_to_de\",\n    \"translation_en_to_ro\",\n]\n\n\nclass OnnxConverterArgumentParser(ArgumentParser):\n    \"\"\"\n    Wraps all the script arguments supported to export transformers models to ONNX IR\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(\"ONNX Converter\")\n\n        self.add_argument(\n            \"--pipeline\",\n            type=str,\n            choices=SUPPORTED_PIPELINES,\n            default=\"feature-extraction\",\n        )\n        self.add_argument(\n            \"--model\",\n            type=str,\n            required=True,\n            help=\"Model's id or 
path (ex: bert-base-cased)\",\n        )\n        self.add_argument(\"--tokenizer\", type=str, help=\"Tokenizer's id or path (ex: bert-base-cased)\")\n        self.add_argument(\n            \"--framework\",\n            type=str,\n            choices=[\"pt\", \"tf\"],\n            help=\"Framework for loading the model\",\n        )\n        self.add_argument(\"--opset\", type=int, default=11, help=\"ONNX opset to use\")\n        self.add_argument(\n            \"--check-loading\",\n            action=\"store_true\",\n            help=\"Check ONNX is able to load the model\",\n        )\n        self.add_argument(\n            \"--use-external-format\",\n            action=\"store_true\",\n            help=\"Allow exporting model >= than 2Gb\",\n        )\n        self.add_argument(\n            \"--quantize\",\n            action=\"store_true\",\n            help=\"Quantize the neural network to be run with int8\",\n        )\n        self.add_argument(\"output\")\n\n\ndef generate_identified_filename(filename: Path, identifier: str) -> Path:\n    \"\"\"\n    Append a string-identifier at the end (before the extension, if any) to the provided filepath\n\n    Args:\n        filename: pathlib.Path The actual path object we would like to add an identifier suffix\n        identifier: The suffix to add\n\n    Returns: String with concatenated identifier at the end of the filename\n    \"\"\"\n    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)\n\n\ndef check_onnxruntime_requirements(minimum_version: Version):\n    \"\"\"\n    Check onnxruntime is installed and if the installed version match is recent enough\n\n    Raises:\n        ImportError: If onnxruntime is not installed or too old version is found\n    \"\"\"\n    try:\n        import onnxruntime\n\n        # Parse the version of the installed onnxruntime\n        ort_version = parse(onnxruntime.__version__)\n\n        # We require 1.4.0 minimum\n        if ort_version < 
ORT_QUANTIZE_MINIMUM_VERSION:\n            raise ImportError(\n                f\"We found an older version of onnxruntime ({onnxruntime.__version__}) \"\n                f\"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\\n\"\n                f\"Please update onnxruntime by running `pip install --upgrade onnxruntime`\"\n            )\n\n    except ImportError:\n        raise ImportError(\n            \"onnxruntime doesn't seem to be currently installed. \"\n            \"Please install the onnxruntime by running `pip install onnxruntime`\"\n            \" and relaunch the conversion.\"\n        )\n\n\ndef ensure_valid_input(model, tokens, input_names):\n    \"\"\"\n    Ensure input are presented in the correct order, without any Non\n\n    Args:\n        model: The model used to forward the input data\n        tokens: BatchEncoding holding the input data\n        input_names: The name of the inputs\n\n    Returns: Tuple\n\n    \"\"\"\n    print(\"Ensuring inputs are in correct order\")\n\n    model_args_name = model.forward.__code__.co_varnames\n    model_args, ordered_input_names = [], []\n    for arg_name in model_args_name[1:]:  # start at index 1 to skip \"self\" argument\n        if arg_name in input_names:\n            ordered_input_names.append(arg_name)\n            model_args.append(tokens[arg_name])\n        else:\n            print(f\"{arg_name} is not present in the generated input list.\")\n            break\n\n    print(f\"Generated inputs order: {ordered_input_names}\")\n    return ordered_input_names, tuple(model_args)\n\n\ndef infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:\n    \"\"\"\n    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model\n\n    Args:\n        nlp: The pipeline object holding the model to be exported\n        framework: The framework identifier to dispatch to the correct inference scheme 
(pt/tf)\n\n    Returns:\n\n        - List of the inferred input variable names\n        - List of the inferred output variable names\n        - Dictionary with input/output variables names as key and shape tensor as value\n        - a BatchEncoding reference which was used to infer all the above information\n    \"\"\"\n\n    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):\n        if isinstance(tensor, (tuple, list)):\n            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]\n\n        else:\n            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)\n            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: \"batch\"}\n            if is_input:\n                if len(tensor.shape) == 2:\n                    axes[1] = \"sequence\"\n                else:\n                    raise ValueError(f\"Unable to infer tensor axes ({len(tensor.shape)})\")\n            else:\n                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]\n                axes.update({dim: \"sequence\" for dim in seq_axes})\n\n        print(f\"Found {'input' if is_input else 'output'} {name} with shape: {axes}\")\n        return axes\n\n    tokens = nlp.tokenizer(\"This is a sample output\", return_tensors=framework)\n    seq_len = tokens.input_ids.shape[-1]\n    outputs = nlp.model(**tokens) if framework == \"pt\" else nlp.model(tokens)\n    if isinstance(outputs, ModelOutput):\n        outputs = outputs.to_tuple()\n    if not isinstance(outputs, (list, tuple)):\n        outputs = (outputs,)\n\n    # Generate input names & axes\n    input_vars = list(tokens.keys())\n    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}\n\n    # flatten potentially grouped outputs (past for gpt2, attentions)\n    outputs_flat = []\n    for output in outputs:\n        if isinstance(output, (tuple, list)):\n           
 outputs_flat.extend(output)\n        else:\n            outputs_flat.append(output)\n\n    # Generate output names & axes\n    output_names = [f\"output_{i}\" for i in range(len(outputs_flat))]\n    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}\n\n    # Create the aggregated axes representation\n    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)\n    return input_vars, output_names, dynamic_axes, tokens\n\n\ndef load_graph_from_args(\n    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs\n) -> Pipeline:\n    \"\"\"\n    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model\n\n    Args:\n        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)\n        framework: The actual model to convert the pipeline from (\"pt\" or \"tf\")\n        model: The model name which will be loaded by the pipeline\n        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value\n\n    Returns: Pipeline object\n\n    \"\"\"\n    # If no tokenizer provided\n    if tokenizer is None:\n        tokenizer = model\n\n    # Check the wanted framework is available\n    if framework == \"pt\" and not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n    if framework == \"tf\" and not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. 
Please install tensorflow first.\")\n\n    print(f\"Loading pipeline (model: {model}, tokenizer: {tokenizer})\")\n\n    # Allocate tokenizer and model\n    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)\n\n\ndef convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):\n    \"\"\"\n    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB\n\n    Returns:\n\n    \"\"\"\n    if not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n\n    import torch\n    from torch.onnx import export\n\n    print(f\"Using framework PyTorch: {torch.__version__}\")\n\n    with torch.no_grad():\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"pt\")\n        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)\n\n        export(\n            nlp.model,\n            model_args,\n            f=output.as_posix(),\n            input_names=ordered_input_names,\n            output_names=output_names,\n            dynamic_axes=dynamic_axes,\n            do_constant_folding=True,\n            use_external_data_format=use_external_format,\n            enable_onnx_checker=True,\n            opset_version=opset,\n        )\n\n\ndef convert_tensorflow(nlp: Pipeline, opset: int, output: Path):\n    \"\"\"\n    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the 
generated ONNX model\n\n    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow\n\n    \"\"\"\n    if not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. Please install tensorflow first.\")\n\n    print(\"/!\\\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\\\\")\n\n    try:\n        import tensorflow as tf\n\n        from keras2onnx import __version__ as k2ov\n        from keras2onnx import convert_keras, save_model\n\n        print(f\"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}\")\n\n        # Build\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"tf\")\n\n        # Forward\n        nlp.model.predict(tokens.data)\n        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)\n        save_model(onnx_model, output.as_posix())\n\n    except ImportError as e:\n        raise Exception(f\"Cannot import {e.name} required to convert TF model to ONNX. 
Please install {e.name} first.\")\n\n\ndef convert(\n    framework: str,\n    model: str,\n    output: Path,\n    opset: int,\n    tokenizer: Optional[str] = None,\n    use_external_format: bool = False,\n    pipeline_name: str = \"feature-extraction\",\n    **model_kwargs\n):\n    \"\"\"\n    Convert the pipeline object to the ONNX Intermediate Representation (IR) format\n\n    Args:\n        framework: The framework the pipeline is backed by (\"pt\" or \"tf\")\n        model: The name of the model to load for the pipeline\n        output: The path where the ONNX graph will be stored\n        opset: The actual version of the ONNX operator set to use\n        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)\n        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)\n        model_kwargs: Keyword arguments to be forwarded to the model constructor\n\n    Returns:\n\n    \"\"\"\n    print(f\"ONNX opset version set to: {opset}\")\n\n    # Load the pipeline\n    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)\n\n    if not output.parent.exists():\n        print(f\"Creating folder {output.parent}\")\n        makedirs(output.parent.as_posix())\n    elif len(listdir(output.parent.as_posix())) > 0:\n        raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n\n    # Export the graph\n    if framework == \"pt\":\n        convert_pytorch(nlp, opset, output, use_external_format)\n    else:\n        convert_tensorflow(nlp, opset, output)\n\n\ndef optimize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the\n    optimizations possibl\n\n    Args:\n        onnx_model_path: filepath where 
the model binary description is stored\n\n    Returns: Path where the optimized model binary description has been saved\n\n    \"\"\"\n    from onnxruntime import InferenceSession, SessionOptions\n\n    # Generate model name with suffix \"optimized\"\n    opt_model_path = generate_identified_filename(onnx_model_path, \"-optimized\")\n    sess_option = SessionOptions()\n    sess_option.optimized_model_filepath = opt_model_path.as_posix()\n    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)\n\n    print(f\"Optimized model has been written at {opt_model_path}: \\N{heavy check mark}\")\n    print(\"/!\\\\ Optimized model contains hardware specific operators which might not be portable. /!\\\\\")\n\n    return opt_model_path\n\n\ndef quantize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU\n\n    Args:\n        onnx_model_path: Path to location the exported ONNX model is stored\n\n    Returns: The Path generated for the quantized\n    \"\"\"\n    import onnx\n    from onnxruntime.quantization import QuantizationMode, quantize\n\n    onnx_model = onnx.load(onnx_model_path.as_posix())\n\n    # Discussed with @yufenglee from ONNX runtime, this will be address in the next release of onnxruntime\n    print(\n        \"As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\\n\"\n        \"This limitation will be removed in the next release of onnxruntime.\"\n    )\n\n    quantized_model = quantize(\n        model=onnx_model,\n        quantization_mode=QuantizationMode.IntegerOps,\n        force_fusions=True,\n        symmetric_weight=True,\n    )\n\n    # Append \"-quantized\" at the end of the model's name\n    quantized_model_path = generate_identified_filename(onnx_model_path, \"-quantized\")\n\n    # Save model\n    print(f\"Quantized model has been written at {quantized_model_path}: \\N{heavy check mark}\")\n    
onnx.save_model(quantized_model, quantized_model_path.as_posix())\n\n    return quantized_model_path\n\n\ndef verify(path: Path):\n    from onnxruntime import InferenceSession, SessionOptions\n    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException\n\n    print(f\"Checking ONNX model loading from: {path} ...\")\n    try:\n        onnx_options = SessionOptions()\n        _ = InferenceSession(path.as_posix(), onnx_options, providers=[\"CPUExecutionProvider\"])\n        print(f\"Model {path} correctly loaded: \\N{heavy check mark}\")\n    except RuntimeException as re:\n        print(f\"Error while loading the model {re}: \\N{heavy ballot x}\")\n\n\nif __name__ == \"__main__\":\n    parser = OnnxConverterArgumentParser()\n    args = parser.parse_args()\n\n    # Make sure output is absolute path\n    args.output = Path(args.output).absolute()\n\n    try:\n        print(\"\\n====== Converting model to ONNX ======\")\n        # Convert\n        convert(\n            args.framework,\n            args.model,\n            args.output,\n            args.opset,\n            args.tokenizer,\n            args.use_external_format,\n            args.pipeline,\n        )\n\n        if args.quantize:\n            # Ensure requirements for quantization on onnxruntime is met\n            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)\n\n            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch\n            if args.framework == \"tf\":\n                print(\n                    \"\\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\\n\"\n                    \"\\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\\n\"\n                    \"\\t For more information, please refer to the onnxruntime documentation:\\n\"\n                    
\"\\t\\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\\n\"\n                )\n\n            print(\"\\n====== Optimizing ONNX model ======\")\n\n            # Quantization works best when using the optimized version of the model\n            args.optimized_output = optimize(args.output)\n\n            # Do the quantization on the right graph\n            args.quantized_output = quantize(args.optimized_output)\n\n        # And verify\n        if args.check_loading:\n            print(\"\\n====== Check exported ONNX model(s) ======\")\n            verify(args.output)\n\n            if hasattr(args, \"optimized_output\"):\n                verify(args.optimized_output)\n\n            if hasattr(args, \"quantized_output\"):\n                verify(args.quantized_output)\n\n    except Exception as e:\n        print(f\"Error while converting the model: {e}\")\n        exit(1)\n"
  },
  {
    "path": "patches/transformers/4.5.0/convert_graph_to_onnx.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os \nimport json\n\nfrom argparse import ArgumentParser\nfrom os import listdir, makedirs\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple\n\nfrom packaging.version import Version, parse\n\nfrom transformers.file_utils import ModelOutput, is_tf_available, is_torch_available\nfrom transformers.pipelines import Pipeline, pipeline\nfrom transformers.tokenization_utils import BatchEncoding\n\n\n# This is the minimal required version to\n# support some ONNX Runtime features\nORT_QUANTIZE_MINIMUM_VERSION = parse(\"1.4.0\")\n\n\nSUPPORTED_PIPELINES = [\n    \"feature-extraction\",\n    \"ner\",\n    \"sentiment-analysis\",\n    \"fill-mask\",\n    \"question-answering\",\n    \"text-generation\",\n    \"translation_en_to_fr\",\n    \"translation_en_to_de\",\n    \"translation_en_to_ro\",\n]\n\n\nclass OnnxConverterArgumentParser(ArgumentParser):\n    \"\"\"\n    Wraps all the script arguments supported to export transformers models to ONNX IR\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(\"ONNX Converter\")\n\n        self.add_argument(\n            \"--pipeline\",\n            type=str,\n            choices=SUPPORTED_PIPELINES,\n            default=\"feature-extraction\",\n        )\n        self.add_argument(\n            \"--model\",\n            type=str,\n            required=True,\n      
      help=\"Model's id or path (ex: bert-base-cased)\",\n        )\n        self.add_argument(\"--tokenizer\", type=str, help=\"Tokenizer's id or path (ex: bert-base-cased)\")\n        self.add_argument(\n            \"--framework\",\n            type=str,\n            choices=[\"pt\", \"tf\"],\n            help=\"Framework for loading the model\",\n        )\n        self.add_argument(\"--opset\", type=int, default=11, help=\"ONNX opset to use\")\n        self.add_argument(\n            \"--check-loading\",\n            action=\"store_true\",\n            help=\"Check ONNX is able to load the model\",\n        )\n        self.add_argument(\n            \"--use-external-format\",\n            action=\"store_true\",\n            help=\"Allow exporting model >= than 2Gb\",\n        )\n        self.add_argument(\n            \"--save-config\",\n            action=\"store_true\",\n            help=\"Save the model configuration along with the ONNX\",\n        )\n        self.add_argument(\n            \"--quantize\",\n            action=\"store_true\",\n            help=\"Quantize the neural network to be run with int8\",\n        )\n        self.add_argument(\"output\")\n\n\ndef generate_identified_filename(filename: Path, identifier: str) -> Path:\n    \"\"\"\n    Append a string-identifier at the end (before the extension, if any) to the provided filepath\n\n    Args:\n        filename: pathlib.Path The actual path object we would like to add an identifier suffix\n        identifier: The suffix to add\n\n    Returns: String with concatenated identifier at the end of the filename\n    \"\"\"\n    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)\n\n\ndef check_onnxruntime_requirements(minimum_version: Version):\n    \"\"\"\n    Check onnxruntime is installed and if the installed version match is recent enough\n\n    Raises:\n        ImportError: If onnxruntime is not installed or too old version is found\n    \"\"\"\n    try:\n 
       import onnxruntime\n\n        # Parse the version of the installed onnxruntime\n        ort_version = parse(onnxruntime.__version__)\n\n        # We require 1.4.0 minimum\n        if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:\n            raise ImportError(\n                f\"We found an older version of onnxruntime ({onnxruntime.__version__}) \"\n                f\"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\\n\"\n                f\"Please update onnxruntime by running `pip install --upgrade onnxruntime`\"\n            )\n\n    except ImportError:\n        raise ImportError(\n            \"onnxruntime doesn't seem to be currently installed. \"\n            \"Please install the onnxruntime by running `pip install onnxruntime`\"\n            \" and relaunch the conversion.\"\n        )\n\n\ndef ensure_valid_input(model, tokens, input_names):\n    \"\"\"\n    Ensure input are presented in the correct order, without any Non\n\n    Args:\n        model: The model used to forward the input data\n        tokens: BatchEncoding holding the input data\n        input_names: The name of the inputs\n\n    Returns: Tuple\n\n    \"\"\"\n    print(\"Ensuring inputs are in correct order\")\n\n    model_args_name = model.forward.__code__.co_varnames\n    model_args, ordered_input_names = [], []\n    for arg_name in model_args_name[1:]:  # start at index 1 to skip \"self\" argument\n        if arg_name in input_names:\n            ordered_input_names.append(arg_name)\n            model_args.append(tokens[arg_name])\n        else:\n            print(f\"{arg_name} is not present in the generated input list.\")\n            break\n\n    print(f\"Generated inputs order: {ordered_input_names}\")\n    return ordered_input_names, tuple(model_args)\n\n\ndef infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:\n    \"\"\"\n    Attempt to infer the static vs dynamic axes for each input and 
output tensors for a specific model\n\n    Args:\n        nlp: The pipeline object holding the model to be exported\n        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)\n\n    Returns:\n\n        - List of the inferred input variable names\n        - List of the inferred output variable names\n        - Dictionary with input/output variables names as key and shape tensor as value\n        - a BatchEncoding reference which was used to infer all the above information\n    \"\"\"\n\n    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):\n        if isinstance(tensor, (tuple, list)):\n            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]\n\n        else:\n            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)\n            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: \"batch\"}\n            if is_input:\n                if len(tensor.shape) == 2:\n                    axes[1] = \"sequence\"\n                else:\n                    raise ValueError(f\"Unable to infer tensor axes ({len(tensor.shape)})\")\n            else:\n                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]\n                axes.update({dim: \"sequence\" for dim in seq_axes})\n\n        print(f\"Found {'input' if is_input else 'output'} {name} with shape: {axes}\")\n        return axes\n\n    tokens = nlp.tokenizer(\"This is a sample output\", return_tensors=framework)\n    seq_len = tokens.input_ids.shape[-1]\n    outputs = nlp.model(**tokens) if framework == \"pt\" else nlp.model(tokens)\n    if isinstance(outputs, ModelOutput):\n        outputs = outputs.to_tuple()\n    if not isinstance(outputs, (list, tuple)):\n        outputs = (outputs,)\n\n    # Generate input names & axes\n    input_vars = list(tokens.keys())\n    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for 
k, v in tokens.items()}\n\n    # flatten potentially grouped outputs (past for gpt2, attentions)\n    outputs_flat = []\n    for output in outputs:\n        if isinstance(output, (tuple, list)):\n            outputs_flat.extend(output)\n        else:\n            outputs_flat.append(output)\n\n    # Generate output names & axes\n    output_names = [f\"output_{i}\" for i in range(len(outputs_flat))]\n    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}\n\n    # Create the aggregated axes representation\n    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)\n    return input_vars, output_names, dynamic_axes, tokens\n\n\ndef load_graph_from_args(\n    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs\n) -> Pipeline:\n    \"\"\"\n    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model\n\n    Args:\n        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)\n        framework: The actual model to convert the pipeline from (\"pt\" or \"tf\")\n        model: The model name which will be loaded by the pipeline\n        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value\n\n    Returns: Pipeline object\n\n    \"\"\"\n    # If no tokenizer provided\n    if tokenizer is None:\n        tokenizer = model\n\n    # Check the wanted framework is available\n    if framework == \"pt\" and not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n    if framework == \"tf\" and not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. 
Please install tensorflow first.\")\n\n    print(f\"Loading pipeline (model: {model}, tokenizer: {tokenizer})\")\n\n    # Allocate tokenizer and model\n    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)\n\n\ndef convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):\n    \"\"\"\n    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB\n\n    Returns:\n\n    \"\"\"\n    if not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n\n    import torch\n    from torch.onnx import export\n\n    print(f\"Using framework PyTorch: {torch.__version__}\")\n\n    with torch.no_grad():\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"pt\")\n        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)\n\n        print('Exporting from PyTorch to ONNX...')\n        print('input_names', input_names)\n        print('output_names', output_names)\n        print('dynamic_axes', dynamic_axes)\n        print('tokens', tokens)\n        print('model_args', model_args)\n        \n        export(\n            nlp.model,\n            model_args,\n            f=output.as_posix(),\n            input_names=ordered_input_names,\n            output_names=output_names,\n            dynamic_axes=dynamic_axes,\n            do_constant_folding=True,\n            use_external_data_format=use_external_format,\n            enable_onnx_checker=True,\n            opset_version=opset,\n            verbose=True\n        )\n\n\ndef convert_tensorflow(nlp: Pipeline, 
opset: int, output: Path):\n    \"\"\"\n    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n\n    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow\n\n    \"\"\"\n    if not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. Please install tensorflow first.\")\n\n    print(\"/!\\\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\\\\")\n\n    try:\n        import tensorflow as tf\n\n        from keras2onnx import __version__ as k2ov\n        from keras2onnx import convert_keras, save_model\n\n        print(f\"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}\")\n\n        # Build\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"tf\")\n\n        # Forward\n        nlp.model.predict(tokens.data)\n        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)\n        save_model(onnx_model, output.as_posix())\n\n    except ImportError as e:\n        raise Exception(f\"Cannot import {e.name} required to convert TF model to ONNX. 
Please install {e.name} first.\")\n\n\ndef convert(\n    framework: str,\n    model: str,\n    output: Path,\n    opset: int,\n    tokenizer: Optional[str] = None,\n    use_external_format: bool = False,\n    pipeline_name: str = \"feature-extraction\",\n    save_config: bool = False,\n    **model_kwargs\n):\n    \"\"\"\n    Convert the pipeline object to the ONNX Intermediate Representation (IR) format\n\n    Args:\n        framework: The framework the pipeline is backed by (\"pt\" or \"tf\")\n        model: The name of the model to load for the pipeline\n        output: The path where the ONNX graph will be stored\n        opset: The actual version of the ONNX operator set to use\n        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)\n        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)\n        model_kwargs: Keyword arguments to be forwarded to the model constructor\n\n    Returns:\n\n    \"\"\"\n    print(f\"ONNX opset version set to: {opset}\")\n\n    # Load the pipeline\n    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)\n\n    if not output.parent.exists():\n        print(f\"Creating folder {output.parent}\")\n        makedirs(output.parent.as_posix())\n    #elif len(listdir(output.parent.as_posix())) > 0:\n    #    raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n\n    # Export the graph\n    if framework == \"pt\":\n        convert_pytorch(nlp, opset, output, use_external_format)\n    else:\n        convert_tensorflow(nlp, opset, output)\n        \n    # Save the configuration\n    if save_config:\n        config_path = os.path.splitext(output)[0] + '.json'\n\n        config = dict(\n            model = nlp.model.config.to_dict(),\n            tokenizer = 
nlp.tokenizer.init_kwargs\n        )\n        \n        #nlp.model.config.to_json_file(config_path)\n        \n        with open(config_path, 'w') as config_file:\n            json.dump(config, config_file, indent=2)\n            \n        print(f\"Saved config to {config_path}\")\n        \n\ndef optimize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the\n    optimizations possibl\n\n    Args:\n        onnx_model_path: filepath where the model binary description is stored\n\n    Returns: Path where the optimized model binary description has been saved\n\n    \"\"\"\n    from onnxruntime import InferenceSession, SessionOptions\n\n    # Generate model name with suffix \"optimized\"\n    opt_model_path = generate_identified_filename(onnx_model_path, \"-optimized\")\n    sess_option = SessionOptions()\n    sess_option.optimized_model_filepath = opt_model_path.as_posix()\n    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)\n\n    print(f\"Optimized model has been written at {opt_model_path}: \\N{heavy check mark}\")\n    print(\"/!\\\\ Optimized model contains hardware specific operators which might not be portable. 
/!\\\\\")\n\n    return opt_model_path\n\n\ndef quantize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU\n\n    Args:\n        onnx_model_path: Path to location the exported ONNX model is stored\n\n    Returns: The Path generated for the quantized\n    \"\"\"\n    import onnx\n    from onnxruntime.quantization import QuantizationMode, quantize\n\n    onnx_model = onnx.load(onnx_model_path.as_posix())\n\n    # Discussed with @yufenglee from ONNX runtime, this will be address in the next release of onnxruntime\n    print(\n        \"As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\\n\"\n        \"This limitation will be removed in the next release of onnxruntime.\"\n    )\n\n    quantized_model = quantize(\n        model=onnx_model,\n        quantization_mode=QuantizationMode.IntegerOps,\n        force_fusions=True,\n        symmetric_weight=True,\n    )\n\n    # Append \"-quantized\" at the end of the model's name\n    quantized_model_path = generate_identified_filename(onnx_model_path, \"-quantized\")\n\n    # Save model\n    print(f\"Quantized model has been written at {quantized_model_path}: \\N{heavy check mark}\")\n    onnx.save_model(quantized_model, quantized_model_path.as_posix())\n\n    return quantized_model_path\n\n\ndef verify(path: Path):\n    from onnxruntime import InferenceSession, SessionOptions\n    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException\n\n    print(f\"Checking ONNX model loading from: {path} ...\")\n    try:\n        onnx_options = SessionOptions()\n        _ = InferenceSession(path.as_posix(), onnx_options, providers=[\"CPUExecutionProvider\"])\n        print(f\"Model {path} correctly loaded: \\N{heavy check mark}\")\n    except RuntimeException as re:\n        print(f\"Error while loading the model {re}: \\N{heavy ballot x}\")\n\n\nif __name__ == \"__main__\":\n 
   parser = OnnxConverterArgumentParser()\n    args = parser.parse_args()\n\n    # Make sure output is absolute path\n    args.output = Path(args.output).absolute()\n\n    try:\n        print(\"\\n====== Converting model to ONNX ======\")\n        # Convert\n        convert(\n            args.framework,\n            args.model,\n            args.output,\n            args.opset,\n            args.tokenizer,\n            args.use_external_format,\n            args.pipeline,\n            args.save_config\n        )\n\n        if args.quantize:\n            # Ensure requirements for quantization on onnxruntime is met\n            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)\n\n            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch\n            if args.framework == \"tf\":\n                print(\n                    \"\\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\\n\"\n                    \"\\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\\n\"\n                    \"\\t For more information, please refer to the onnxruntime documentation:\\n\"\n                    \"\\t\\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\\n\"\n                )\n\n            print(\"\\n====== Optimizing ONNX model ======\")\n\n            # Quantization works best when using the optimized version of the model\n            args.optimized_output = optimize(args.output)\n\n            # Do the quantization on the right graph\n            args.quantized_output = quantize(args.optimized_output)\n\n        # And verify\n        if args.check_loading:\n            print(\"\\n====== Check exported ONNX model(s) ======\")\n            verify(args.output)\n\n            if hasattr(args, \"optimized_output\"):\n                verify(args.optimized_output)\n\n            if hasattr(args, 
\"quantized_output\"):\n                verify(args.quantized_output)\n\n    except Exception as e:\n        print(f\"Error while converting the model: {e}\")\n        exit(1)\n"
  },
  {
    "path": "patches/transformers/4.5.0/modeling_distilbert.py",
    "content": "# coding=utf-8\n# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in\n part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)\n\"\"\"\n\n\nimport copy\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom ...activations import gelu\nfrom ...file_utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n)\nfrom ...modeling_outputs import (\n    BaseModelOutput,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom ...modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom ...utils import logging\nfrom .configuration_distilbert import DistilBertConfig\n\n\nlogger = logging.get_logger(__name__)\n_CHECKPOINT_FOR_DOC = \"distilbert-base-uncased\"\n_CONFIG_FOR_DOC = \"DistilBertConfig\"\n_TOKENIZER_FOR_DOC = \"DistilBertTokenizer\"\n\nDISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"distilbert-base-uncased\",\n    
\"distilbert-base-uncased-distilled-squad\",\n    \"distilbert-base-cased\",\n    \"distilbert-base-cased-distilled-squad\",\n    \"distilbert-base-german-cased\",\n    \"distilbert-base-multilingual-cased\",\n    \"distilbert-base-uncased-finetuned-sst-2-english\",\n    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert\n]\n\n\n# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #\n\n\ndef create_sinusoidal_embeddings(n_pos, dim, out):\n    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])\n    out.requires_grad = False\n    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))\n    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))\n    out.detach_()\n\n\nclass Embeddings(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)\n        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)\n        if config.sinusoidal_pos_embds:\n            create_sinusoidal_embeddings(\n                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight\n            )\n\n        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.dropout = nn.Dropout(config.dropout)\n\n    def forward(self, input_ids):\n        \"\"\"\n        Parameters:\n            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.\n\n        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type\n        embeddings)\n        \"\"\"\n        seq_length = input_ids.size(1)\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)\n        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)\n\n        word_embeddings = 
self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)\n        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)\n\n        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)\n        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)\n        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)\n        return embeddings\n\n\nclass MultiHeadSelfAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        self.n_heads = config.n_heads\n        self.dim = config.dim\n        self.dropout = nn.Dropout(p=config.attention_dropout)\n\n        assert self.dim % self.n_heads == 0\n\n        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        attention_head_size = self.dim // self.n_heads\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)\n        # Prune linear layers\n        self.q_lin = prune_linear_layer(self.q_lin, index)\n        self.k_lin = prune_linear_layer(self.k_lin, index)\n        self.v_lin = prune_linear_layer(self.v_lin, index)\n        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)\n        # Update hyper params\n        self.n_heads = self.n_heads - len(heads)\n        self.dim = attention_head_size * self.n_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            query: torch.tensor(bs, 
seq_length, dim)\n            key: torch.tensor(bs, seq_length, dim)\n            value: torch.tensor(bs, seq_length, dim)\n            mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,\n            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`\n        \"\"\"\n        bs, q_length, dim = query.size()\n        k_length = key.size(1)\n        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'\n        # assert key.size() == value.size()\n\n        dim_per_head = self.dim // self.n_heads\n\n        mask_reshp = (bs, 1, 1, k_length)\n\n        def shape(x):\n            \"\"\" separate heads \"\"\"\n            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)\n\n        def unshape(x):\n            \"\"\" group heads \"\"\"\n            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)\n\n        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)\n        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)\n        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)\n\n        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)\n        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)\n        mask = mask.view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)\n        scores.masked_fill_((mask == 0), -float(\"inf\"))  # (bs, n_heads, q_length, k_length)\n\n        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)\n        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            weights = weights * head_mask\n\n        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, 
dim_per_head)\n        context = unshape(context)  # (bs, q_length, dim)\n        context = self.out_lin(context)  # (bs, q_length, dim)\n\n        if output_attentions:\n            return (context, weights)\n        else:\n            return (context,)\n\n\nclass FFN(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dropout = nn.Dropout(p=config.dropout)\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)\n        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)\n        assert config.activation in [\"relu\", \"gelu\"], f\"activation ({config.activation}) must be in ['relu', 'gelu']\"\n        self.activation = gelu if config.activation == \"gelu\" else nn.ReLU()\n\n    def forward(self, input):\n        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)\n\n    def ff_chunk(self, input):\n        x = self.lin1(input)\n        x = self.activation(x)\n        x = self.lin2(x)\n        x = self.dropout(x)\n        return x\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        assert config.dim % config.n_heads == 0\n\n        self.attention = MultiHeadSelfAttention(config)\n        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n        self.ffn = FFN(config)\n        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim)\n            attn_mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:\n            torch.tensor(bs, 
seq_length, dim) The output of the transformer block contextualization.\n        \"\"\"\n        # Self-Attention\n        sa_output = self.attention(\n            query=x,\n            key=x,\n            value=x,\n            mask=attn_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n        )\n        if output_attentions:\n            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)\n        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples\n            assert type(sa_output) == tuple\n            sa_output = sa_output[0]\n        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)\n\n        # Feed Forward Network\n        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)\n        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)\n\n        output = (ffn_output,)\n        if output_attentions:\n            output = (sa_weights,) + output\n        return output\n\n\nclass Transformer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.n_layers = config.n_layers\n\n        layer = TransformerBlock(config)\n        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])\n\n    def forward(\n        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None\n    ):  # docstyle-ignore\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.\n            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.\n\n        Returns:\n            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)\n            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]\n                Tuple of length n_layers with the hidden states from each layer.\n 
               Optional: only if output_hidden_states=True\n            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]\n                Tuple of length n_layers with the attention weights from each layer\n                Optional: only if output_attentions=True\n        \"\"\"\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_state = x\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_state,)\n\n            layer_outputs = layer_module(\n                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions\n            )\n            hidden_state = layer_outputs[-1]\n\n            if output_attentions:\n                assert len(layer_outputs) == 2\n                attentions = layer_outputs[0]\n                all_attentions = all_attentions + (attentions,)\n            else:\n                assert len(layer_outputs) == 1\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_state,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)\n        return BaseModelOutput(\n            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions\n        )\n\n\n# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #\nclass DistilBertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = DistilBertConfig\n    load_tf_weights = None\n    base_model_prefix = \"distilbert\"\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if 
isinstance(module, nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\nDISTILBERT_START_DOCSTRING = r\"\"\"\n\n    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic\n    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,\n    pruning heads etc.)\n\n    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__\n    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to\n    general usage and behavior.\n\n    Parameters:\n        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model\n            weights.\n\"\"\"\n\nDISTILBERT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. 
See\n            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n            details.\n\n            `What are input IDs? <../glossary.html#input-ids>`__\n        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):\n            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            `What are attention masks? <../glossary.html#attention-mask>`__\n        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):\n            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated\n            vectors than the model's internal embedding lookup matrix.\n        output_attentions (:obj:`bool`, `optional`):\n            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n            tensors for more detail.\n        output_hidden_states (:obj:`bool`, `optional`):\n            Whether or not to return the hidden states of all layers. 
See ``hidden_states`` under returned tensors for\n            more detail.\n        return_dict (:obj:`bool`, `optional`):\n            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertModel(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embeddings = Embeddings(config)  # Embeddings\n        self.transformer = Transformer(config)  # Encoder\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, new_embeddings):\n        self.embeddings.word_embeddings = new_embeddings\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.transformer.layer[layer].attention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if attention_mask is None:\n            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)\n\n        # Prepare head mask if needed\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)\n        return self.transformer(\n            x=inputs_embeds,\n            attn_mask=attention_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"DistilBert Model with a `masked language modeling` head on top. 
\"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMaskedLM(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.vocab_transform = nn.Linear(config.dim, config.dim)\n        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)\n\n        self.init_weights()\n\n        self.mlm_loss_fct = nn.CrossEntropyLoss()\n\n    def get_output_embeddings(self):\n        return self.vocab_projector\n\n    def set_output_embeddings(self, new_embeddings):\n        self.vocab_projector = new_embeddings\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=MaskedLMOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        dlbrt_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)\n        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)\n\n        mlm_loss = None\n        if labels is not None:\n            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))\n\n        if not return_dict:\n            output = (prediction_logits,) + dlbrt_output[1:]\n            return ((mlm_loss,) + output) if mlm_loss is not None else output\n\n        return MaskedLMOutput(\n            loss=mlm_loss,\n            logits=prediction_logits,\n            hidden_states=dlbrt_output.hidden_states,\n            attentions=dlbrt_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the\n    pooled output) e.g. 
for GLUE tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForSequenceClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, config.num_labels)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=SequenceClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n            config.num_labels - 1]`. 
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs, dim)\n        logits = self.classifier(pooled_output)  # (bs, num_labels)\n\n        loss = None\n        if labels is not None:\n            if self.num_labels == 1:\n                loss_fct = nn.MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            else:\n                loss_fct = nn.CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + distilbert_output[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a\n    linear layers on top of the hidden-states output to compute 
`span start logits` and `span end logits`).\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForQuestionAnswering(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.qa_outputs = nn.Linear(config.dim, config.num_labels)\n        assert config.num_labels == 2\n        self.dropout = nn.Dropout(config.qa_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=QuestionAnsweringModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        start_positions=None,\n        end_positions=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). 
Position outside of the\n            sequence are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)\n\n        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)\n        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)\n        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions.clamp_(0, ignored_index)\n            end_positions.clamp_(0, ignored_index)\n\n            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + distilbert_output[1:]\n        
    return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.\n    for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForTokenClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.dropout = nn.Dropout(config.dropout)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the token classification loss. 
Indices should be in ``[0, ..., config.num_labels -\n            1]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        sequence_output = self.dropout(sequence_output)\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            # Only keep active parts of the loss\n            if attention_mask is not None:\n                active_loss = attention_mask.view(-1) == 1\n                active_logits = logits.view(-1, self.num_labels)\n                active_labels = torch.where(\n                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n                )\n                loss = loss_fct(active_logits, active_labels)\n            else:\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and\n    a softmax) e.g. 
for RocStories/SWAG tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMultipleChoice(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, 1)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(\n        DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices, sequence_length\")\n    )\n    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,\n            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. 
(See\n            :obj:`input_ids` above)\n\n        Returns:\n\n        Examples::\n\n            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice\n            >>> import torch\n\n            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')\n            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')\n\n            >>> prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n            >>> choice0 = \"It is eaten with a fork and a knife.\"\n            >>> choice1 = \"It is eaten while held in the hand.\"\n            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1\n\n            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)\n            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1\n\n            >>> # the linear classifier still needs to be trained\n            >>> loss = outputs.loss\n            >>> logits = outputs.logits\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]\n\n        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None\n        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None\n        inputs_embeds = (\n            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))\n            if inputs_embeds is not None\n            else None\n        )\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)\n        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)\n\n        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(reshaped_logits, labels)\n\n        if not return_dict:\n            output = (reshaped_logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return MultipleChoiceModelOutput(\n            loss=loss,\n            logits=reshaped_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "patches/transformers/4.5.1/convert_graph_to_onnx.diff",
    "content": "14a15,17\n> import os \n> import json\n> \n83a87,91\n>             \"--save-config\",\n>             action=\"store_true\",\n>             help=\"Save the model configuration along with the ONNX\",\n>         )\n>         self.add_argument(\n280a289,295\n>         print('Exporting from PyTorch to ONNX...')\n>         print('input_names', input_names)\n>         print('output_names', output_names)\n>         print('dynamic_axes', dynamic_axes)\n>         print('tokens', tokens)\n>         print('model_args', model_args)\n>         \n291a307\n>             verbose=True\n339a356\n>     save_config: bool = False,\n366,367c383,384\n<     elif len(listdir(output.parent.as_posix())) > 0:\n<         raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n---\n>     #elif len(listdir(output.parent.as_posix())) > 0:\n>     #    raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n374c391,407\n< \n---\n>         \n>     # Save the configuration\n>     if save_config:\n>         config_path = os.path.splitext(output)[0] + '.json'\n> \n>         config = dict(\n>             model = nlp.model.config.to_dict(),\n>             tokenizer = nlp.tokenizer.init_kwargs\n>         )\n>         \n>         #nlp.model.config.to_json_file(config_path)\n>         \n>         with open(config_path, 'w') as config_file:\n>             json.dump(config, config_file, indent=2)\n>             \n>         print(f\"Saved config to {config_path}\")\n>         \n468a502\n>             args.save_config\n"
  },
  {
    "path": "patches/transformers/4.5.1/convert_graph_to_onnx.original.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom argparse import ArgumentParser\nfrom os import listdir, makedirs\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple\n\nfrom packaging.version import Version, parse\n\nfrom transformers.file_utils import ModelOutput, is_tf_available, is_torch_available\nfrom transformers.pipelines import Pipeline, pipeline\nfrom transformers.tokenization_utils import BatchEncoding\n\n\n# This is the minimal required version to\n# support some ONNX Runtime features\nORT_QUANTIZE_MINIMUM_VERSION = parse(\"1.4.0\")\n\n\nSUPPORTED_PIPELINES = [\n    \"feature-extraction\",\n    \"ner\",\n    \"sentiment-analysis\",\n    \"fill-mask\",\n    \"question-answering\",\n    \"text-generation\",\n    \"translation_en_to_fr\",\n    \"translation_en_to_de\",\n    \"translation_en_to_ro\",\n]\n\n\nclass OnnxConverterArgumentParser(ArgumentParser):\n    \"\"\"\n    Wraps all the script arguments supported to export transformers models to ONNX IR\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(\"ONNX Converter\")\n\n        self.add_argument(\n            \"--pipeline\",\n            type=str,\n            choices=SUPPORTED_PIPELINES,\n            default=\"feature-extraction\",\n        )\n        self.add_argument(\n            \"--model\",\n            type=str,\n            required=True,\n            help=\"Model's id or 
path (ex: bert-base-cased)\",\n        )\n        self.add_argument(\"--tokenizer\", type=str, help=\"Tokenizer's id or path (ex: bert-base-cased)\")\n        self.add_argument(\n            \"--framework\",\n            type=str,\n            choices=[\"pt\", \"tf\"],\n            help=\"Framework for loading the model\",\n        )\n        self.add_argument(\"--opset\", type=int, default=11, help=\"ONNX opset to use\")\n        self.add_argument(\n            \"--check-loading\",\n            action=\"store_true\",\n            help=\"Check ONNX is able to load the model\",\n        )\n        self.add_argument(\n            \"--use-external-format\",\n            action=\"store_true\",\n            help=\"Allow exporting model >= than 2Gb\",\n        )\n        self.add_argument(\n            \"--quantize\",\n            action=\"store_true\",\n            help=\"Quantize the neural network to be run with int8\",\n        )\n        self.add_argument(\"output\")\n\n\ndef generate_identified_filename(filename: Path, identifier: str) -> Path:\n    \"\"\"\n    Append a string-identifier at the end (before the extension, if any) to the provided filepath\n\n    Args:\n        filename: pathlib.Path The actual path object we would like to add an identifier suffix\n        identifier: The suffix to add\n\n    Returns: String with concatenated identifier at the end of the filename\n    \"\"\"\n    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)\n\n\ndef check_onnxruntime_requirements(minimum_version: Version):\n    \"\"\"\n    Check onnxruntime is installed and if the installed version match is recent enough\n\n    Raises:\n        ImportError: If onnxruntime is not installed or too old version is found\n    \"\"\"\n    try:\n        import onnxruntime\n\n        # Parse the version of the installed onnxruntime\n        ort_version = parse(onnxruntime.__version__)\n\n        # We require 1.4.0 minimum\n        if ort_version < 
ORT_QUANTIZE_MINIMUM_VERSION:\n            raise ImportError(\n                f\"We found an older version of onnxruntime ({onnxruntime.__version__}) \"\n                f\"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\\n\"\n                f\"Please update onnxruntime by running `pip install --upgrade onnxruntime`\"\n            )\n\n    except ImportError:\n        raise ImportError(\n            \"onnxruntime doesn't seem to be currently installed. \"\n            \"Please install the onnxruntime by running `pip install onnxruntime`\"\n            \" and relaunch the conversion.\"\n        )\n\n\ndef ensure_valid_input(model, tokens, input_names):\n    \"\"\"\n    Ensure input are presented in the correct order, without any Non\n\n    Args:\n        model: The model used to forward the input data\n        tokens: BatchEncoding holding the input data\n        input_names: The name of the inputs\n\n    Returns: Tuple\n\n    \"\"\"\n    print(\"Ensuring inputs are in correct order\")\n\n    model_args_name = model.forward.__code__.co_varnames\n    model_args, ordered_input_names = [], []\n    for arg_name in model_args_name[1:]:  # start at index 1 to skip \"self\" argument\n        if arg_name in input_names:\n            ordered_input_names.append(arg_name)\n            model_args.append(tokens[arg_name])\n        else:\n            print(f\"{arg_name} is not present in the generated input list.\")\n            break\n\n    print(f\"Generated inputs order: {ordered_input_names}\")\n    return ordered_input_names, tuple(model_args)\n\n\ndef infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:\n    \"\"\"\n    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model\n\n    Args:\n        nlp: The pipeline object holding the model to be exported\n        framework: The framework identifier to dispatch to the correct inference scheme 
(pt/tf)\n\n    Returns:\n\n        - List of the inferred input variable names\n        - List of the inferred output variable names\n        - Dictionary with input/output variables names as key and shape tensor as value\n        - a BatchEncoding reference which was used to infer all the above information\n    \"\"\"\n\n    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):\n        if isinstance(tensor, (tuple, list)):\n            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]\n\n        else:\n            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)\n            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: \"batch\"}\n            if is_input:\n                if len(tensor.shape) == 2:\n                    axes[1] = \"sequence\"\n                else:\n                    raise ValueError(f\"Unable to infer tensor axes ({len(tensor.shape)})\")\n            else:\n                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]\n                axes.update({dim: \"sequence\" for dim in seq_axes})\n\n        print(f\"Found {'input' if is_input else 'output'} {name} with shape: {axes}\")\n        return axes\n\n    tokens = nlp.tokenizer(\"This is a sample output\", return_tensors=framework)\n    seq_len = tokens.input_ids.shape[-1]\n    outputs = nlp.model(**tokens) if framework == \"pt\" else nlp.model(tokens)\n    if isinstance(outputs, ModelOutput):\n        outputs = outputs.to_tuple()\n    if not isinstance(outputs, (list, tuple)):\n        outputs = (outputs,)\n\n    # Generate input names & axes\n    input_vars = list(tokens.keys())\n    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}\n\n    # flatten potentially grouped outputs (past for gpt2, attentions)\n    outputs_flat = []\n    for output in outputs:\n        if isinstance(output, (tuple, list)):\n           
 outputs_flat.extend(output)\n        else:\n            outputs_flat.append(output)\n\n    # Generate output names & axes\n    output_names = [f\"output_{i}\" for i in range(len(outputs_flat))]\n    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}\n\n    # Create the aggregated axes representation\n    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)\n    return input_vars, output_names, dynamic_axes, tokens\n\n\ndef load_graph_from_args(\n    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs\n) -> Pipeline:\n    \"\"\"\n    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model\n\n    Args:\n        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)\n        framework: The actual model to convert the pipeline from (\"pt\" or \"tf\")\n        model: The model name which will be loaded by the pipeline\n        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value\n\n    Returns: Pipeline object\n\n    \"\"\"\n    # If no tokenizer provided\n    if tokenizer is None:\n        tokenizer = model\n\n    # Check the wanted framework is available\n    if framework == \"pt\" and not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n    if framework == \"tf\" and not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. 
Please install tensorflow first.\")\n\n    print(f\"Loading pipeline (model: {model}, tokenizer: {tokenizer})\")\n\n    # Allocate tokenizer and model\n    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)\n\n\ndef convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):\n    \"\"\"\n    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB\n\n    Returns:\n\n    \"\"\"\n    if not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n\n    import torch\n    from torch.onnx import export\n\n    print(f\"Using framework PyTorch: {torch.__version__}\")\n\n    with torch.no_grad():\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"pt\")\n        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)\n\n        export(\n            nlp.model,\n            model_args,\n            f=output.as_posix(),\n            input_names=ordered_input_names,\n            output_names=output_names,\n            dynamic_axes=dynamic_axes,\n            do_constant_folding=True,\n            use_external_data_format=use_external_format,\n            enable_onnx_checker=True,\n            opset_version=opset,\n        )\n\n\ndef convert_tensorflow(nlp: Pipeline, opset: int, output: Path):\n    \"\"\"\n    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the 
generated ONNX model\n\n    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow\n\n    \"\"\"\n    if not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. Please install tensorflow first.\")\n\n    print(\"/!\\\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\\\\")\n\n    try:\n        import tensorflow as tf\n\n        from keras2onnx import __version__ as k2ov\n        from keras2onnx import convert_keras, save_model\n\n        print(f\"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}\")\n\n        # Build\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"tf\")\n\n        # Forward\n        nlp.model.predict(tokens.data)\n        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)\n        save_model(onnx_model, output.as_posix())\n\n    except ImportError as e:\n        raise Exception(f\"Cannot import {e.name} required to convert TF model to ONNX. 
Please install {e.name} first.\")\n\n\ndef convert(\n    framework: str,\n    model: str,\n    output: Path,\n    opset: int,\n    tokenizer: Optional[str] = None,\n    use_external_format: bool = False,\n    pipeline_name: str = \"feature-extraction\",\n    **model_kwargs\n):\n    \"\"\"\n    Convert the pipeline object to the ONNX Intermediate Representation (IR) format\n\n    Args:\n        framework: The framework the pipeline is backed by (\"pt\" or \"tf\")\n        model: The name of the model to load for the pipeline\n        output: The path where the ONNX graph will be stored\n        opset: The actual version of the ONNX operator set to use\n        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)\n        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)\n        model_kwargs: Keyword arguments to be forwarded to the model constructor\n\n    Returns:\n\n    \"\"\"\n    print(f\"ONNX opset version set to: {opset}\")\n\n    # Load the pipeline\n    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)\n\n    if not output.parent.exists():\n        print(f\"Creating folder {output.parent}\")\n        makedirs(output.parent.as_posix())\n    elif len(listdir(output.parent.as_posix())) > 0:\n        raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n\n    # Export the graph\n    if framework == \"pt\":\n        convert_pytorch(nlp, opset, output, use_external_format)\n    else:\n        convert_tensorflow(nlp, opset, output)\n\n\ndef optimize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the\n    optimizations possibl\n\n    Args:\n        onnx_model_path: filepath where 
the model binary description is stored\n\n    Returns: Path where the optimized model binary description has been saved\n\n    \"\"\"\n    from onnxruntime import InferenceSession, SessionOptions\n\n    # Generate model name with suffix \"optimized\"\n    opt_model_path = generate_identified_filename(onnx_model_path, \"-optimized\")\n    sess_option = SessionOptions()\n    sess_option.optimized_model_filepath = opt_model_path.as_posix()\n    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)\n\n    print(f\"Optimized model has been written at {opt_model_path}: \\N{heavy check mark}\")\n    print(\"/!\\\\ Optimized model contains hardware specific operators which might not be portable. /!\\\\\")\n\n    return opt_model_path\n\n\ndef quantize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU\n\n    Args:\n        onnx_model_path: Path to location the exported ONNX model is stored\n\n    Returns: The Path generated for the quantized\n    \"\"\"\n    import onnx\n    from onnxruntime.quantization import QuantizationMode, quantize\n\n    onnx_model = onnx.load(onnx_model_path.as_posix())\n\n    # Discussed with @yufenglee from ONNX runtime, this will be address in the next release of onnxruntime\n    print(\n        \"As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\\n\"\n        \"This limitation will be removed in the next release of onnxruntime.\"\n    )\n\n    quantized_model = quantize(\n        model=onnx_model,\n        quantization_mode=QuantizationMode.IntegerOps,\n        force_fusions=True,\n        symmetric_weight=True,\n    )\n\n    # Append \"-quantized\" at the end of the model's name\n    quantized_model_path = generate_identified_filename(onnx_model_path, \"-quantized\")\n\n    # Save model\n    print(f\"Quantized model has been written at {quantized_model_path}: \\N{heavy check mark}\")\n    
onnx.save_model(quantized_model, quantized_model_path.as_posix())\n\n    return quantized_model_path\n\n\ndef verify(path: Path):\n    from onnxruntime import InferenceSession, SessionOptions\n    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException\n\n    print(f\"Checking ONNX model loading from: {path} ...\")\n    try:\n        onnx_options = SessionOptions()\n        _ = InferenceSession(path.as_posix(), onnx_options, providers=[\"CPUExecutionProvider\"])\n        print(f\"Model {path} correctly loaded: \\N{heavy check mark}\")\n    except RuntimeException as re:\n        print(f\"Error while loading the model {re}: \\N{heavy ballot x}\")\n\n\nif __name__ == \"__main__\":\n    parser = OnnxConverterArgumentParser()\n    args = parser.parse_args()\n\n    # Make sure output is absolute path\n    args.output = Path(args.output).absolute()\n\n    try:\n        print(\"\\n====== Converting model to ONNX ======\")\n        # Convert\n        convert(\n            args.framework,\n            args.model,\n            args.output,\n            args.opset,\n            args.tokenizer,\n            args.use_external_format,\n            args.pipeline,\n        )\n\n        if args.quantize:\n            # Ensure requirements for quantization on onnxruntime is met\n            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)\n\n            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch\n            if args.framework == \"tf\":\n                print(\n                    \"\\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\\n\"\n                    \"\\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\\n\"\n                    \"\\t For more information, please refer to the onnxruntime documentation:\\n\"\n                    
\"\\t\\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\\n\"\n                )\n\n            print(\"\\n====== Optimizing ONNX model ======\")\n\n            # Quantization works best when using the optimized version of the model\n            args.optimized_output = optimize(args.output)\n\n            # Do the quantization on the right graph\n            args.quantized_output = quantize(args.optimized_output)\n\n        # And verify\n        if args.check_loading:\n            print(\"\\n====== Check exported ONNX model(s) ======\")\n            verify(args.output)\n\n            if hasattr(args, \"optimized_output\"):\n                verify(args.optimized_output)\n\n            if hasattr(args, \"quantized_output\"):\n                verify(args.quantized_output)\n\n    except Exception as e:\n        print(f\"Error while converting the model: {e}\")\n        exit(1)\n"
  },
  {
    "path": "patches/transformers/4.5.1/convert_graph_to_onnx.py",
    "content": "# Copyright 2020 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os \nimport json\n\nfrom argparse import ArgumentParser\nfrom os import listdir, makedirs\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple\n\nfrom packaging.version import Version, parse\n\nfrom transformers.file_utils import ModelOutput, is_tf_available, is_torch_available\nfrom transformers.pipelines import Pipeline, pipeline\nfrom transformers.tokenization_utils import BatchEncoding\n\n\n# This is the minimal required version to\n# support some ONNX Runtime features\nORT_QUANTIZE_MINIMUM_VERSION = parse(\"1.4.0\")\n\n\nSUPPORTED_PIPELINES = [\n    \"feature-extraction\",\n    \"ner\",\n    \"sentiment-analysis\",\n    \"fill-mask\",\n    \"question-answering\",\n    \"text-generation\",\n    \"translation_en_to_fr\",\n    \"translation_en_to_de\",\n    \"translation_en_to_ro\",\n]\n\n\nclass OnnxConverterArgumentParser(ArgumentParser):\n    \"\"\"\n    Wraps all the script arguments supported to export transformers models to ONNX IR\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(\"ONNX Converter\")\n\n        self.add_argument(\n            \"--pipeline\",\n            type=str,\n            choices=SUPPORTED_PIPELINES,\n            default=\"feature-extraction\",\n        )\n        self.add_argument(\n            \"--model\",\n            type=str,\n            required=True,\n      
      help=\"Model's id or path (ex: bert-base-cased)\",\n        )\n        self.add_argument(\"--tokenizer\", type=str, help=\"Tokenizer's id or path (ex: bert-base-cased)\")\n        self.add_argument(\n            \"--framework\",\n            type=str,\n            choices=[\"pt\", \"tf\"],\n            help=\"Framework for loading the model\",\n        )\n        self.add_argument(\"--opset\", type=int, default=11, help=\"ONNX opset to use\")\n        self.add_argument(\n            \"--check-loading\",\n            action=\"store_true\",\n            help=\"Check ONNX is able to load the model\",\n        )\n        self.add_argument(\n            \"--use-external-format\",\n            action=\"store_true\",\n            help=\"Allow exporting model >= than 2Gb\",\n        )\n        self.add_argument(\n            \"--save-config\",\n            action=\"store_true\",\n            help=\"Save the model configuration along with the ONNX\",\n        )\n        self.add_argument(\n            \"--quantize\",\n            action=\"store_true\",\n            help=\"Quantize the neural network to be run with int8\",\n        )\n        self.add_argument(\"output\")\n\n\ndef generate_identified_filename(filename: Path, identifier: str) -> Path:\n    \"\"\"\n    Append a string-identifier at the end (before the extension, if any) to the provided filepath\n\n    Args:\n        filename: pathlib.Path The actual path object we would like to add an identifier suffix\n        identifier: The suffix to add\n\n    Returns: String with concatenated identifier at the end of the filename\n    \"\"\"\n    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)\n\n\ndef check_onnxruntime_requirements(minimum_version: Version):\n    \"\"\"\n    Check onnxruntime is installed and if the installed version match is recent enough\n\n    Raises:\n        ImportError: If onnxruntime is not installed or too old version is found\n    \"\"\"\n    try:\n 
       import onnxruntime\n\n        # Parse the version of the installed onnxruntime\n        ort_version = parse(onnxruntime.__version__)\n\n        # We require 1.4.0 minimum\n        if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:\n            raise ImportError(\n                f\"We found an older version of onnxruntime ({onnxruntime.__version__}) \"\n                f\"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\\n\"\n                f\"Please update onnxruntime by running `pip install --upgrade onnxruntime`\"\n            )\n\n    except ImportError:\n        raise ImportError(\n            \"onnxruntime doesn't seem to be currently installed. \"\n            \"Please install the onnxruntime by running `pip install onnxruntime`\"\n            \" and relaunch the conversion.\"\n        )\n\n\ndef ensure_valid_input(model, tokens, input_names):\n    \"\"\"\n    Ensure input are presented in the correct order, without any Non\n\n    Args:\n        model: The model used to forward the input data\n        tokens: BatchEncoding holding the input data\n        input_names: The name of the inputs\n\n    Returns: Tuple\n\n    \"\"\"\n    print(\"Ensuring inputs are in correct order\")\n\n    model_args_name = model.forward.__code__.co_varnames\n    model_args, ordered_input_names = [], []\n    for arg_name in model_args_name[1:]:  # start at index 1 to skip \"self\" argument\n        if arg_name in input_names:\n            ordered_input_names.append(arg_name)\n            model_args.append(tokens[arg_name])\n        else:\n            print(f\"{arg_name} is not present in the generated input list.\")\n            break\n\n    print(f\"Generated inputs order: {ordered_input_names}\")\n    return ordered_input_names, tuple(model_args)\n\n\ndef infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:\n    \"\"\"\n    Attempt to infer the static vs dynamic axes for each input and 
output tensors for a specific model\n\n    Args:\n        nlp: The pipeline object holding the model to be exported\n        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)\n\n    Returns:\n\n        - List of the inferred input variable names\n        - List of the inferred output variable names\n        - Dictionary with input/output variables names as key and shape tensor as value\n        - a BatchEncoding reference which was used to infer all the above information\n    \"\"\"\n\n    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):\n        if isinstance(tensor, (tuple, list)):\n            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]\n\n        else:\n            # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)\n            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: \"batch\"}\n            if is_input:\n                if len(tensor.shape) == 2:\n                    axes[1] = \"sequence\"\n                else:\n                    raise ValueError(f\"Unable to infer tensor axes ({len(tensor.shape)})\")\n            else:\n                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]\n                axes.update({dim: \"sequence\" for dim in seq_axes})\n\n        print(f\"Found {'input' if is_input else 'output'} {name} with shape: {axes}\")\n        return axes\n\n    tokens = nlp.tokenizer(\"This is a sample output\", return_tensors=framework)\n    seq_len = tokens.input_ids.shape[-1]\n    outputs = nlp.model(**tokens) if framework == \"pt\" else nlp.model(tokens)\n    if isinstance(outputs, ModelOutput):\n        outputs = outputs.to_tuple()\n    if not isinstance(outputs, (list, tuple)):\n        outputs = (outputs,)\n\n    # Generate input names & axes\n    input_vars = list(tokens.keys())\n    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for 
k, v in tokens.items()}\n\n    # flatten potentially grouped outputs (past for gpt2, attentions)\n    outputs_flat = []\n    for output in outputs:\n        if isinstance(output, (tuple, list)):\n            outputs_flat.extend(output)\n        else:\n            outputs_flat.append(output)\n\n    # Generate output names & axes\n    output_names = [f\"output_{i}\" for i in range(len(outputs_flat))]\n    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}\n\n    # Create the aggregated axes representation\n    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)\n    return input_vars, output_names, dynamic_axes, tokens\n\n\ndef load_graph_from_args(\n    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs\n) -> Pipeline:\n    \"\"\"\n    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model\n\n    Args:\n        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)\n        framework: The actual model to convert the pipeline from (\"pt\" or \"tf\")\n        model: The model name which will be loaded by the pipeline\n        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value\n\n    Returns: Pipeline object\n\n    \"\"\"\n    # If no tokenizer provided\n    if tokenizer is None:\n        tokenizer = model\n\n    # Check the wanted framework is available\n    if framework == \"pt\" and not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n    if framework == \"tf\" and not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. 
Please install tensorflow first.\")\n\n    print(f\"Loading pipeline (model: {model}, tokenizer: {tokenizer})\")\n\n    # Allocate tokenizer and model\n    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)\n\n\ndef convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):\n    \"\"\"\n    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB\n\n    Returns:\n\n    \"\"\"\n    if not is_torch_available():\n        raise Exception(\"Cannot convert because PyTorch is not installed. Please install torch first.\")\n\n    import torch\n    from torch.onnx import export\n\n    print(f\"Using framework PyTorch: {torch.__version__}\")\n\n    with torch.no_grad():\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"pt\")\n        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)\n\n        print('Exporting from PyTorch to ONNX...')\n        print('input_names', input_names)\n        print('output_names', output_names)\n        print('dynamic_axes', dynamic_axes)\n        print('tokens', tokens)\n        print('model_args', model_args)\n        \n        export(\n            nlp.model,\n            model_args,\n            f=output.as_posix(),\n            input_names=ordered_input_names,\n            output_names=output_names,\n            dynamic_axes=dynamic_axes,\n            do_constant_folding=True,\n            use_external_data_format=use_external_format,\n            enable_onnx_checker=True,\n            opset_version=opset,\n            verbose=True\n        )\n\n\ndef convert_tensorflow(nlp: Pipeline, 
opset: int, output: Path):\n    \"\"\"\n    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR\n\n    Args:\n        nlp: The pipeline to be exported\n        opset: The actual version of the ONNX operator set to use\n        output: Path where will be stored the generated ONNX model\n\n    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow\n\n    \"\"\"\n    if not is_tf_available():\n        raise Exception(\"Cannot convert because TF is not installed. Please install tensorflow first.\")\n\n    print(\"/!\\\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\\\\")\n\n    try:\n        import tensorflow as tf\n\n        from keras2onnx import __version__ as k2ov\n        from keras2onnx import convert_keras, save_model\n\n        print(f\"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}\")\n\n        # Build\n        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, \"tf\")\n\n        # Forward\n        nlp.model.predict(tokens.data)\n        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)\n        save_model(onnx_model, output.as_posix())\n\n    except ImportError as e:\n        raise Exception(f\"Cannot import {e.name} required to convert TF model to ONNX. 
Please install {e.name} first.\")\n\n\ndef convert(\n    framework: str,\n    model: str,\n    output: Path,\n    opset: int,\n    tokenizer: Optional[str] = None,\n    use_external_format: bool = False,\n    pipeline_name: str = \"feature-extraction\",\n    save_config: bool = False,\n    **model_kwargs\n):\n    \"\"\"\n    Convert the pipeline object to the ONNX Intermediate Representation (IR) format\n\n    Args:\n        framework: The framework the pipeline is backed by (\"pt\" or \"tf\")\n        model: The name of the model to load for the pipeline\n        output: The path where the ONNX graph will be stored\n        opset: The actual version of the ONNX operator set to use\n        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided\n        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)\n        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)\n        model_kwargs: Keyword arguments to be forwarded to the model constructor\n\n    Returns:\n\n    \"\"\"\n    print(f\"ONNX opset version set to: {opset}\")\n\n    # Load the pipeline\n    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)\n\n    if not output.parent.exists():\n        print(f\"Creating folder {output.parent}\")\n        makedirs(output.parent.as_posix())\n    #elif len(listdir(output.parent.as_posix())) > 0:\n    #    raise Exception(f\"Folder {output.parent.as_posix()} is not empty, aborting conversion\")\n\n    # Export the graph\n    if framework == \"pt\":\n        convert_pytorch(nlp, opset, output, use_external_format)\n    else:\n        convert_tensorflow(nlp, opset, output)\n        \n    # Save the configuration\n    if save_config:\n        config_path = os.path.splitext(output)[0] + '.json'\n\n        config = dict(\n            model = nlp.model.config.to_dict(),\n            tokenizer = 
nlp.tokenizer.init_kwargs\n        )\n        \n        #nlp.model.config.to_json_file(config_path)\n        \n        with open(config_path, 'w') as config_file:\n            json.dump(config, config_file, indent=2)\n            \n        print(f\"Saved config to {config_path}\")\n        \n\ndef optimize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the\n    optimizations possibl\n\n    Args:\n        onnx_model_path: filepath where the model binary description is stored\n\n    Returns: Path where the optimized model binary description has been saved\n\n    \"\"\"\n    from onnxruntime import InferenceSession, SessionOptions\n\n    # Generate model name with suffix \"optimized\"\n    opt_model_path = generate_identified_filename(onnx_model_path, \"-optimized\")\n    sess_option = SessionOptions()\n    sess_option.optimized_model_filepath = opt_model_path.as_posix()\n    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)\n\n    print(f\"Optimized model has been written at {opt_model_path}: \\N{heavy check mark}\")\n    print(\"/!\\\\ Optimized model contains hardware specific operators which might not be portable. 
/!\\\\\")\n\n    return opt_model_path\n\n\ndef quantize(onnx_model_path: Path) -> Path:\n    \"\"\"\n    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU\n\n    Args:\n        onnx_model_path: Path to location the exported ONNX model is stored\n\n    Returns: The Path generated for the quantized\n    \"\"\"\n    import onnx\n    from onnxruntime.quantization import QuantizationMode, quantize\n\n    onnx_model = onnx.load(onnx_model_path.as_posix())\n\n    # Discussed with @yufenglee from ONNX runtime, this will be address in the next release of onnxruntime\n    print(\n        \"As of onnxruntime 1.4.0, models larger than 2GB will fail to quantize due to protobuf constraint.\\n\"\n        \"This limitation will be removed in the next release of onnxruntime.\"\n    )\n\n    quantized_model = quantize(\n        model=onnx_model,\n        quantization_mode=QuantizationMode.IntegerOps,\n        force_fusions=True,\n        symmetric_weight=True,\n    )\n\n    # Append \"-quantized\" at the end of the model's name\n    quantized_model_path = generate_identified_filename(onnx_model_path, \"-quantized\")\n\n    # Save model\n    print(f\"Quantized model has been written at {quantized_model_path}: \\N{heavy check mark}\")\n    onnx.save_model(quantized_model, quantized_model_path.as_posix())\n\n    return quantized_model_path\n\n\ndef verify(path: Path):\n    from onnxruntime import InferenceSession, SessionOptions\n    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException\n\n    print(f\"Checking ONNX model loading from: {path} ...\")\n    try:\n        onnx_options = SessionOptions()\n        _ = InferenceSession(path.as_posix(), onnx_options, providers=[\"CPUExecutionProvider\"])\n        print(f\"Model {path} correctly loaded: \\N{heavy check mark}\")\n    except RuntimeException as re:\n        print(f\"Error while loading the model {re}: \\N{heavy ballot x}\")\n\n\nif __name__ == \"__main__\":\n 
   parser = OnnxConverterArgumentParser()\n    args = parser.parse_args()\n\n    # Make sure output is absolute path\n    args.output = Path(args.output).absolute()\n\n    try:\n        print(\"\\n====== Converting model to ONNX ======\")\n        # Convert\n        convert(\n            args.framework,\n            args.model,\n            args.output,\n            args.opset,\n            args.tokenizer,\n            args.use_external_format,\n            args.pipeline,\n            args.save_config\n        )\n\n        if args.quantize:\n            # Ensure requirements for quantization on onnxruntime is met\n            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)\n\n            # onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch\n            if args.framework == \"tf\":\n                print(\n                    \"\\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\\n\"\n                    \"\\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\\n\"\n                    \"\\t For more information, please refer to the onnxruntime documentation:\\n\"\n                    \"\\t\\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\\n\"\n                )\n\n            print(\"\\n====== Optimizing ONNX model ======\")\n\n            # Quantization works best when using the optimized version of the model\n            args.optimized_output = optimize(args.output)\n\n            # Do the quantization on the right graph\n            args.quantized_output = quantize(args.optimized_output)\n\n        # And verify\n        if args.check_loading:\n            print(\"\\n====== Check exported ONNX model(s) ======\")\n            verify(args.output)\n\n            if hasattr(args, \"optimized_output\"):\n                verify(args.optimized_output)\n\n            if hasattr(args, 
\"quantized_output\"):\n                verify(args.quantized_output)\n\n    except Exception as e:\n        print(f\"Error while converting the model: {e}\")\n        exit(1)\n"
  },
  {
    "path": "patches/transformers/4.5.1/modeling_distilbert.diff",
    "content": "183,184c183,184\n<         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)\n<         scores.masked_fill_(mask, -float(\"inf\"))  # (bs, n_heads, q_length, k_length)\n---\n>         mask = mask.view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)\n>         scores.masked_fill_((mask == 0), -float(\"inf\"))  # (bs, n_heads, q_length, k_length)\n"
  },
  {
    "path": "patches/transformers/4.5.1/modeling_distilbert.original.py",
    "content": "# coding=utf-8\n# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in\n part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)\n\"\"\"\n\n\nimport copy\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom ...activations import gelu\nfrom ...file_utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n)\nfrom ...modeling_outputs import (\n    BaseModelOutput,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom ...modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom ...utils import logging\nfrom .configuration_distilbert import DistilBertConfig\n\n\nlogger = logging.get_logger(__name__)\n_CHECKPOINT_FOR_DOC = \"distilbert-base-uncased\"\n_CONFIG_FOR_DOC = \"DistilBertConfig\"\n_TOKENIZER_FOR_DOC = \"DistilBertTokenizer\"\n\nDISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"distilbert-base-uncased\",\n    
\"distilbert-base-uncased-distilled-squad\",\n    \"distilbert-base-cased\",\n    \"distilbert-base-cased-distilled-squad\",\n    \"distilbert-base-german-cased\",\n    \"distilbert-base-multilingual-cased\",\n    \"distilbert-base-uncased-finetuned-sst-2-english\",\n    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert\n]\n\n\n# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #\n\n\ndef create_sinusoidal_embeddings(n_pos, dim, out):\n    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])\n    out.requires_grad = False\n    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))\n    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))\n    out.detach_()\n\n\nclass Embeddings(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)\n        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)\n        if config.sinusoidal_pos_embds:\n            create_sinusoidal_embeddings(\n                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight\n            )\n\n        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.dropout = nn.Dropout(config.dropout)\n\n    def forward(self, input_ids):\n        \"\"\"\n        Parameters:\n            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.\n\n        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type\n        embeddings)\n        \"\"\"\n        seq_length = input_ids.size(1)\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)\n        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)\n\n        word_embeddings = 
self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)\n        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)\n\n        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)\n        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)\n        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)\n        return embeddings\n\n\nclass MultiHeadSelfAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        self.n_heads = config.n_heads\n        self.dim = config.dim\n        self.dropout = nn.Dropout(p=config.attention_dropout)\n\n        assert self.dim % self.n_heads == 0\n\n        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        attention_head_size = self.dim // self.n_heads\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)\n        # Prune linear layers\n        self.q_lin = prune_linear_layer(self.q_lin, index)\n        self.k_lin = prune_linear_layer(self.k_lin, index)\n        self.v_lin = prune_linear_layer(self.v_lin, index)\n        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)\n        # Update hyper params\n        self.n_heads = self.n_heads - len(heads)\n        self.dim = attention_head_size * self.n_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            query: torch.tensor(bs, 
seq_length, dim)\n            key: torch.tensor(bs, seq_length, dim)\n            value: torch.tensor(bs, seq_length, dim)\n            mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,\n            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`\n        \"\"\"\n        bs, q_length, dim = query.size()\n        k_length = key.size(1)\n        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'\n        # assert key.size() == value.size()\n\n        dim_per_head = self.dim // self.n_heads\n\n        mask_reshp = (bs, 1, 1, k_length)\n\n        def shape(x):\n            \"\"\" separate heads \"\"\"\n            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)\n\n        def unshape(x):\n            \"\"\" group heads \"\"\"\n            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)\n\n        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)\n        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)\n        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)\n\n        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)\n        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)\n        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)\n        scores.masked_fill_(mask, -float(\"inf\"))  # (bs, n_heads, q_length, k_length)\n\n        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)\n        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            weights = weights * head_mask\n\n        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, 
dim_per_head)\n        context = unshape(context)  # (bs, q_length, dim)\n        context = self.out_lin(context)  # (bs, q_length, dim)\n\n        if output_attentions:\n            return (context, weights)\n        else:\n            return (context,)\n\n\nclass FFN(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dropout = nn.Dropout(p=config.dropout)\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)\n        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)\n        assert config.activation in [\"relu\", \"gelu\"], f\"activation ({config.activation}) must be in ['relu', 'gelu']\"\n        self.activation = gelu if config.activation == \"gelu\" else nn.ReLU()\n\n    def forward(self, input):\n        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)\n\n    def ff_chunk(self, input):\n        x = self.lin1(input)\n        x = self.activation(x)\n        x = self.lin2(x)\n        x = self.dropout(x)\n        return x\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        assert config.dim % config.n_heads == 0\n\n        self.attention = MultiHeadSelfAttention(config)\n        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n        self.ffn = FFN(config)\n        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim)\n            attn_mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:\n            torch.tensor(bs, 
seq_length, dim) The output of the transformer block contextualization.\n        \"\"\"\n        # Self-Attention\n        sa_output = self.attention(\n            query=x,\n            key=x,\n            value=x,\n            mask=attn_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n        )\n        if output_attentions:\n            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)\n        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples\n            assert type(sa_output) == tuple\n            sa_output = sa_output[0]\n        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)\n\n        # Feed Forward Network\n        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)\n        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)\n\n        output = (ffn_output,)\n        if output_attentions:\n            output = (sa_weights,) + output\n        return output\n\n\nclass Transformer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.n_layers = config.n_layers\n\n        layer = TransformerBlock(config)\n        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])\n\n    def forward(\n        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None\n    ):  # docstyle-ignore\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.\n            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.\n\n        Returns:\n            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)\n            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]\n                Tuple of length n_layers with the hidden states from each layer.\n 
               Optional: only if output_hidden_states=True\n            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]\n                Tuple of length n_layers with the attention weights from each layer\n                Optional: only if output_attentions=True\n        \"\"\"\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_state = x\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_state,)\n\n            layer_outputs = layer_module(\n                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions\n            )\n            hidden_state = layer_outputs[-1]\n\n            if output_attentions:\n                assert len(layer_outputs) == 2\n                attentions = layer_outputs[0]\n                all_attentions = all_attentions + (attentions,)\n            else:\n                assert len(layer_outputs) == 1\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_state,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)\n        return BaseModelOutput(\n            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions\n        )\n\n\n# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #\nclass DistilBertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = DistilBertConfig\n    load_tf_weights = None\n    base_model_prefix = \"distilbert\"\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if 
isinstance(module, nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\nDISTILBERT_START_DOCSTRING = r\"\"\"\n\n    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic\n    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,\n    pruning heads etc.)\n\n    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__\n    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to\n    general usage and behavior.\n\n    Parameters:\n        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model\n            weights.\n\"\"\"\n\nDISTILBERT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. 
See\n            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n            details.\n\n            `What are input IDs? <../glossary.html#input-ids>`__\n        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):\n            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            `What are attention masks? <../glossary.html#attention-mask>`__\n        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):\n            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated\n            vectors than the model's internal embedding lookup matrix.\n        output_attentions (:obj:`bool`, `optional`):\n            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n            tensors for more detail.\n        output_hidden_states (:obj:`bool`, `optional`):\n            Whether or not to return the hidden states of all layers. 
See ``hidden_states`` under returned tensors for\n            more detail.\n        return_dict (:obj:`bool`, `optional`):\n            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertModel(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embeddings = Embeddings(config)  # Embeddings\n        self.transformer = Transformer(config)  # Encoder\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, new_embeddings):\n        self.embeddings.word_embeddings = new_embeddings\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.transformer.layer[layer].attention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if attention_mask is None:\n            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)\n\n        # Prepare head mask if needed\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)\n        return self.transformer(\n            x=inputs_embeds,\n            attn_mask=attention_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"DistilBert Model with a `masked language modeling` head on top. 
\"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMaskedLM(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.vocab_transform = nn.Linear(config.dim, config.dim)\n        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)\n\n        self.init_weights()\n\n        self.mlm_loss_fct = nn.CrossEntropyLoss()\n\n    def get_output_embeddings(self):\n        return self.vocab_projector\n\n    def set_output_embeddings(self, new_embeddings):\n        self.vocab_projector = new_embeddings\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=MaskedLMOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        dlbrt_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)\n        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)\n\n        mlm_loss = None\n        if labels is not None:\n            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))\n\n        if not return_dict:\n            output = (prediction_logits,) + dlbrt_output[1:]\n            return ((mlm_loss,) + output) if mlm_loss is not None else output\n\n        return MaskedLMOutput(\n            loss=mlm_loss,\n            logits=prediction_logits,\n            hidden_states=dlbrt_output.hidden_states,\n            attentions=dlbrt_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the\n    pooled output) e.g. 
for GLUE tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForSequenceClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, config.num_labels)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=SequenceClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n            config.num_labels - 1]`. 
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs, dim)\n        logits = self.classifier(pooled_output)  # (bs, num_labels)\n\n        loss = None\n        if labels is not None:\n            if self.num_labels == 1:\n                loss_fct = nn.MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            else:\n                loss_fct = nn.CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + distilbert_output[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a\n    linear layers on top of the hidden-states output to compute 
`span start logits` and `span end logits`).\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForQuestionAnswering(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.qa_outputs = nn.Linear(config.dim, config.num_labels)\n        assert config.num_labels == 2\n        self.dropout = nn.Dropout(config.qa_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=QuestionAnsweringModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        start_positions=None,\n        end_positions=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). 
Position outside of the\n            sequence are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)\n\n        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)\n        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)\n        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions.clamp_(0, ignored_index)\n            end_positions.clamp_(0, ignored_index)\n\n            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + distilbert_output[1:]\n        
    return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.\n    for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForTokenClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.dropout = nn.Dropout(config.dropout)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the token classification loss. 
Indices should be in ``[0, ..., config.num_labels -\n            1]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        sequence_output = self.dropout(sequence_output)\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            # Only keep active parts of the loss\n            if attention_mask is not None:\n                active_loss = attention_mask.view(-1) == 1\n                active_logits = logits.view(-1, self.num_labels)\n                active_labels = torch.where(\n                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n                )\n                loss = loss_fct(active_logits, active_labels)\n            else:\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and\n    a softmax) e.g. 
for RocStories/SWAG tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMultipleChoice(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, 1)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(\n        DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices, sequence_length\")\n    )\n    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,\n            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. 
(See\n            :obj:`input_ids` above)\n\n        Returns:\n\n        Examples::\n\n            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice\n            >>> import torch\n\n            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')\n            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')\n\n            >>> prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n            >>> choice0 = \"It is eaten with a fork and a knife.\"\n            >>> choice1 = \"It is eaten while held in the hand.\"\n            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1\n\n            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)\n            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1\n\n            >>> # the linear classifier still needs to be trained\n            >>> loss = outputs.loss\n            >>> logits = outputs.logits\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]\n\n        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None\n        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None\n        inputs_embeds = (\n            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))\n            if inputs_embeds is not None\n            else None\n        )\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)\n        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)\n\n        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(reshaped_logits, labels)\n\n        if not return_dict:\n            output = (reshaped_logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return MultipleChoiceModelOutput(\n            loss=loss,\n            logits=reshaped_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "patches/transformers/4.5.1/modeling_distilbert.py",
    "content": "# coding=utf-8\n# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\n PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in\n part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)\n\"\"\"\n\n\nimport copy\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom ...activations import gelu\nfrom ...file_utils import (\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n)\nfrom ...modeling_outputs import (\n    BaseModelOutput,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom ...modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom ...utils import logging\nfrom .configuration_distilbert import DistilBertConfig\n\n\nlogger = logging.get_logger(__name__)\n_CHECKPOINT_FOR_DOC = \"distilbert-base-uncased\"\n_CONFIG_FOR_DOC = \"DistilBertConfig\"\n_TOKENIZER_FOR_DOC = \"DistilBertTokenizer\"\n\nDISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"distilbert-base-uncased\",\n    
\"distilbert-base-uncased-distilled-squad\",\n    \"distilbert-base-cased\",\n    \"distilbert-base-cased-distilled-squad\",\n    \"distilbert-base-german-cased\",\n    \"distilbert-base-multilingual-cased\",\n    \"distilbert-base-uncased-finetuned-sst-2-english\",\n    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert\n]\n\n\n# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #\n\n\ndef create_sinusoidal_embeddings(n_pos, dim, out):\n    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])\n    out.requires_grad = False\n    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))\n    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))\n    out.detach_()\n\n\nclass Embeddings(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)\n        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)\n        if config.sinusoidal_pos_embds:\n            create_sinusoidal_embeddings(\n                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight\n            )\n\n        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.dropout = nn.Dropout(config.dropout)\n\n    def forward(self, input_ids):\n        \"\"\"\n        Parameters:\n            input_ids: torch.tensor(bs, max_seq_length) The token ids to embed.\n\n        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type\n        embeddings)\n        \"\"\"\n        seq_length = input_ids.size(1)\n        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)\n        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)\n\n        word_embeddings = 
self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)\n        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)\n\n        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)\n        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)\n        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)\n        return embeddings\n\n\nclass MultiHeadSelfAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        self.n_heads = config.n_heads\n        self.dim = config.dim\n        self.dropout = nn.Dropout(p=config.attention_dropout)\n\n        assert self.dim % self.n_heads == 0\n\n        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)\n\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        attention_head_size = self.dim // self.n_heads\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)\n        # Prune linear layers\n        self.q_lin = prune_linear_layer(self.q_lin, index)\n        self.k_lin = prune_linear_layer(self.k_lin, index)\n        self.v_lin = prune_linear_layer(self.v_lin, index)\n        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)\n        # Update hyper params\n        self.n_heads = self.n_heads - len(heads)\n        self.dim = attention_head_size * self.n_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            query: torch.tensor(bs, 
seq_length, dim)\n            key: torch.tensor(bs, seq_length, dim)\n            value: torch.tensor(bs, seq_length, dim)\n            mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,\n            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`\n        \"\"\"\n        bs, q_length, dim = query.size()\n        k_length = key.size(1)\n        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'\n        # assert key.size() == value.size()\n\n        dim_per_head = self.dim // self.n_heads\n\n        mask_reshp = (bs, 1, 1, k_length)\n\n        def shape(x):\n            \"\"\" separate heads \"\"\"\n            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)\n\n        def unshape(x):\n            \"\"\" group heads \"\"\"\n            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)\n\n        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)\n        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)\n        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)\n\n        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)\n        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)\n        mask = mask.view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)\n        scores.masked_fill_((mask == 0), -float(\"inf\"))  # (bs, n_heads, q_length, k_length)\n\n        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)\n        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            weights = weights * head_mask\n\n        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, 
dim_per_head)\n        context = unshape(context)  # (bs, q_length, dim)\n        context = self.out_lin(context)  # (bs, q_length, dim)\n\n        if output_attentions:\n            return (context, weights)\n        else:\n            return (context,)\n\n\nclass FFN(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dropout = nn.Dropout(p=config.dropout)\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)\n        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)\n        assert config.activation in [\"relu\", \"gelu\"], f\"activation ({config.activation}) must be in ['relu', 'gelu']\"\n        self.activation = gelu if config.activation == \"gelu\" else nn.ReLU()\n\n    def forward(self, input):\n        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)\n\n    def ff_chunk(self, input):\n        x = self.lin1(input)\n        x = self.activation(x)\n        x = self.lin2(x)\n        x = self.dropout(x)\n        return x\n\n\nclass TransformerBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n\n        assert config.dim % config.n_heads == 0\n\n        self.attention = MultiHeadSelfAttention(config)\n        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n        self.ffn = FFN(config)\n        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)\n\n    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim)\n            attn_mask: torch.tensor(bs, seq_length)\n\n        Returns:\n            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:\n            torch.tensor(bs, 
seq_length, dim) The output of the transformer block contextualization.\n        \"\"\"\n        # Self-Attention\n        sa_output = self.attention(\n            query=x,\n            key=x,\n            value=x,\n            mask=attn_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n        )\n        if output_attentions:\n            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)\n        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples\n            assert type(sa_output) == tuple\n            sa_output = sa_output[0]\n        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)\n\n        # Feed Forward Network\n        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)\n        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)\n\n        output = (ffn_output,)\n        if output_attentions:\n            output = (sa_weights,) + output\n        return output\n\n\nclass Transformer(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.n_layers = config.n_layers\n\n        layer = TransformerBlock(config)\n        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])\n\n    def forward(\n        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None\n    ):  # docstyle-ignore\n        \"\"\"\n        Parameters:\n            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.\n            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.\n\n        Returns:\n            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)\n            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]\n                Tuple of length n_layers with the hidden states from each layer.\n 
               Optional: only if output_hidden_states=True\n            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]\n                Tuple of length n_layers with the attention weights from each layer\n                Optional: only if output_attentions=True\n        \"\"\"\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_state = x\n        for i, layer_module in enumerate(self.layer):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_state,)\n\n            layer_outputs = layer_module(\n                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions\n            )\n            hidden_state = layer_outputs[-1]\n\n            if output_attentions:\n                assert len(layer_outputs) == 2\n                attentions = layer_outputs[0]\n                all_attentions = all_attentions + (attentions,)\n            else:\n                assert len(layer_outputs) == 1\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_state,)\n\n        if not return_dict:\n            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)\n        return BaseModelOutput(\n            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions\n        )\n\n\n# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #\nclass DistilBertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = DistilBertConfig\n    load_tf_weights = None\n    base_model_prefix = \"distilbert\"\n\n    def _init_weights(self, module):\n        \"\"\"Initialize the weights.\"\"\"\n        if 
isinstance(module, nn.Linear):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n\n\nDISTILBERT_START_DOCSTRING = r\"\"\"\n\n    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic\n    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,\n    pruning heads etc.)\n\n    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__\n    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to\n    general usage and behavior.\n\n    Parameters:\n        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model\n            weights.\n\"\"\"\n\nDISTILBERT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using :class:`~transformers.DistilBertTokenizer`. 
See\n            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n            details.\n\n            `What are input IDs? <../glossary.html#input-ids>`__\n        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):\n            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            `What are attention masks? <../glossary.html#attention-mask>`__\n        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n\n        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):\n            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated\n            vectors than the model's internal embedding lookup matrix.\n        output_attentions (:obj:`bool`, `optional`):\n            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n            tensors for more detail.\n        output_hidden_states (:obj:`bool`, `optional`):\n            Whether or not to return the hidden states of all layers. 
See ``hidden_states`` under returned tensors for\n            more detail.\n        return_dict (:obj:`bool`, `optional`):\n            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertModel(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.embeddings = Embeddings(config)  # Embeddings\n        self.transformer = Transformer(config)  # Encoder\n\n        self.init_weights()\n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, new_embeddings):\n        self.embeddings.word_embeddings = new_embeddings\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.transformer.layer[layer].attention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=BaseModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        if attention_mask is None:\n            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)\n\n        # Prepare head mask if needed\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)\n        return self.transformer(\n            x=inputs_embeds,\n            attn_mask=attention_mask,\n            head_mask=head_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"DistilBert Model with a `masked language modeling` head on top. 
\"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMaskedLM(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.vocab_transform = nn.Linear(config.dim, config.dim)\n        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)\n        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)\n\n        self.init_weights()\n\n        self.mlm_loss_fct = nn.CrossEntropyLoss()\n\n    def get_output_embeddings(self):\n        return self.vocab_projector\n\n    def set_output_embeddings(self, new_embeddings):\n        self.vocab_projector = new_embeddings\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=MaskedLMOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        dlbrt_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)\n        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)\n        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)\n\n        mlm_loss = None\n        if labels is not None:\n            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))\n\n        if not return_dict:\n            output = (prediction_logits,) + dlbrt_output[1:]\n            return ((mlm_loss,) + output) if mlm_loss is not None else output\n\n        return MaskedLMOutput(\n            loss=mlm_loss,\n            logits=prediction_logits,\n            hidden_states=dlbrt_output.hidden_states,\n            attentions=dlbrt_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the\n    pooled output) e.g. 
for GLUE tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForSequenceClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, config.num_labels)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=SequenceClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n            config.num_labels - 1]`. 
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs, dim)\n        logits = self.classifier(pooled_output)  # (bs, num_labels)\n\n        loss = None\n        if labels is not None:\n            if self.num_labels == 1:\n                loss_fct = nn.MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            else:\n                loss_fct = nn.CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + distilbert_output[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a\n    linear layers on top of the hidden-states output to compute 
`span start logits` and `span end logits`).\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForQuestionAnswering(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.qa_outputs = nn.Linear(config.dim, config.num_labels)\n        assert config.num_labels == 2\n        self.dropout = nn.Dropout(config.qa_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=QuestionAnsweringModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        start_positions=None,\n        end_positions=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). 
Position outside of the\n            sequence are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        distilbert_output = self.distilbert(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)\n\n        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)\n        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)\n        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions.clamp_(0, ignored_index)\n            end_positions.clamp_(0, ignored_index)\n\n            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + distilbert_output[1:]\n        
    return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=distilbert_output.hidden_states,\n            attentions=distilbert_output.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.\n    for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForTokenClassification(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.distilbert = DistilBertModel(config)\n        self.dropout = nn.Dropout(config.dropout)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the token classification loss. 
Indices should be in ``[0, ..., config.num_labels -\n            1]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        sequence_output = self.dropout(sequence_output)\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            # Only keep active parts of the loss\n            if attention_mask is not None:\n                active_loss = attention_mask.view(-1) == 1\n                active_logits = logits.view(-1, self.num_labels)\n                active_labels = torch.where(\n                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n                )\n                loss = loss_fct(active_logits, active_labels)\n            else:\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and\n    a softmax) e.g. 
for RocStories/SWAG tasks.\n    \"\"\",\n    DISTILBERT_START_DOCSTRING,\n)\nclass DistilBertForMultipleChoice(DistilBertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.distilbert = DistilBertModel(config)\n        self.pre_classifier = nn.Linear(config.dim, config.dim)\n        self.classifier = nn.Linear(config.dim, 1)\n        self.dropout = nn.Dropout(config.seq_classif_dropout)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(\n        DISTILBERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices, sequence_length\")\n    )\n    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,\n            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. 
(See\n            :obj:`input_ids` above)\n\n        Returns:\n\n        Examples::\n\n            >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice\n            >>> import torch\n\n            >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')\n            >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')\n\n            >>> prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n            >>> choice0 = \"It is eaten with a fork and a knife.\"\n            >>> choice1 = \"It is eaten while held in the hand.\"\n            >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1\n\n            >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)\n            >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1\n\n            >>> # the linear classifier still needs to be trained\n            >>> loss = outputs.loss\n            >>> logits = outputs.logits\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]\n\n        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None\n        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None\n        inputs_embeds = (\n            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))\n            if inputs_embeds is not None\n            else None\n        )\n\n        outputs = self.distilbert(\n            input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)\n        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)\n        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)\n        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)\n        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)\n\n        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(reshaped_logits, labels)\n\n        if not return_dict:\n            output = (reshaped_logits,) + outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return MultipleChoiceModelOutput(\n            loss=loss,\n            logits=reshaped_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "ros/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.5)\nproject(jetson_voice_ros)\n\n# Default to C99\nif(NOT CMAKE_C_STANDARD)\n  set(CMAKE_C_STANDARD 99)\nendif()\n\n# Default to C++14\nif(NOT CMAKE_CXX_STANDARD)\n  set(CMAKE_CXX_STANDARD 14)\nendif()\n\nif(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES \"Clang\")\n  add_compile_options(-Wall -Wextra -Wpedantic)\nendif()\n\n# find dependencies\nfind_package(ament_cmake REQUIRED)\nfind_package(ament_cmake_python REQUIRED)\nfind_package(std_msgs REQUIRED)\nfind_package(rosidl_default_generators REQUIRED)\n\n# generate messages\nrosidl_generate_interfaces(${PROJECT_NAME}\n  \"msg/Audio.msg\"\n  \"msg/AudioInfo.msg\"\n  \"msg/IntentSlot.msg\"\n  \"msg/QuestionAnswerQuery.msg\"\n  \"msg/QuestionAnswerResult.msg\"\n  \"msg/Slot.msg\"\n  DEPENDENCIES std_msgs\n)\n\n# install Python modules\nament_python_install_package(${PROJECT_NAME})\n\n# install Python executables\nfile(GLOB python_nodes ${PROJECT_NAME}/*.py)\n\ninstall(PROGRAMS\n  ${python_nodes}\n  DESTINATION lib/${PROJECT_NAME}\n)\n\n# install launch files\ninstall(DIRECTORY\n  launch\n  DESTINATION share/${PROJECT_NAME}/\n)\n \nif(BUILD_TESTING)\n  find_package(ament_lint_auto REQUIRED)\n  # the following line skips the linter which checks for copyrights\n  # uncomment the line when a copyright and license is not present in all source files\n  #set(ament_cmake_copyright_FOUND TRUE)\n  # the following line skips cpplint (only works in a git repo)\n  # uncomment the line when this package is not in a git repo\n  #set(ament_cmake_cpplint_FOUND TRUE)\n  ament_lint_auto_find_test_dependencies()\nendif()\n\nament_package()\n"
  },
  {
    "path": "ros/jetson_voice_ros/__init__.py",
    "content": ""
  },
  {
    "path": "ros/jetson_voice_ros/asr.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\nimport numpy as np\n\nfrom rclpy.node import Node\nfrom std_msgs.msg import String\n\nfrom jetson_voice import ASR\nfrom jetson_voice_ros.msg import Audio\n\n\nclass ASRNode(Node):\n    def __init__(self):\n        super().__init__('asr', namespace='voice')\n        \n        # create topics\n        self.audio_subscriber = self.create_subscription(Audio, 'audio_in', self.audio_listener, 10)\n        self.transcript_publisher = self.create_publisher(String, 'transcripts', 10)\n        self.partial_transcript_publisher = self.create_publisher(String, 'partial_transcripts', 10)\n        \n        # get node parameters\n        self.declare_parameter('model', 'quartznet')\n        self.model_name = str(self.get_parameter('model').value)\n        self.get_logger().info(f'model = {self.model_name}')\n\n        # load the ASR model\n        self.asr = ASR(self.model_name)\n        self.get_logger().info(f\"model '{self.model_name}' ready\")\n        \n        if self.asr.classification:\n            raise ValueError(f'jetson_voice_ros/asr node does not support ASR classification models')\n        \n    def audio_listener(self, msg):\n        if msg.info.sample_rate != self.asr.sample_rate:\n            self.get_logger().warning(f\"audio has sample_rate {msg.info.sample_rate}, but ASR expects sample_rate {self.asr.sample_rate}\")\n            \n        samples = np.frombuffer(msg.data, dtype=msg.info.sample_format)\n        self.get_logger().debug(f'recieved audio samples {samples.shape} dtype={samples.dtype}') # rms={np.sqrt(np.mean(samples**2))}')\n        \n        results = self.asr(samples)\n        \n        for transcript in results:\n            text = transcript['text'].strip()\n            \n            if len(text) == 0:\n                continue\n                \n            msg = String()\n            msg.data = text\n\n            self.get_logger().info(f\"transcript:  {text}\")\n\n            
if transcript['end']:\n                self.transcript_publisher.publish(msg)\n                \n            self.partial_transcript_publisher.publish(msg)\n                \n\ndef main(args=None):\n    rclpy.init(args=args)\n    node = ASRNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/jetson_voice_ros/audio_input.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\nimport numpy as np\n\nfrom rclpy.node import Node\n\nfrom jetson_voice.utils import AudioInput, audio_to_int16\nfrom jetson_voice_ros.msg import Audio\n\n\nclass AudioInputNode(Node):\n    def __init__(self):\n        super().__init__('audio_input', namespace='voice')\n        \n        # create topics\n        self.audio_publisher = self.create_publisher(Audio, 'audio_in', 10)\n        \n        # get node parameters\n        self.declare_parameter('device', '')          # input audio device ID or name\n        self.declare_parameter('sample_rate', 16000)  # sample rate (in Hz)\n        self.declare_parameter('chunk_size', 16000)   # number of samples per buffer\n        self.declare_parameter('resets', -1)          # number of times to reset the device (-1 is infinite)\n        \n        self.device_name = str(self.get_parameter('device').value)\n        self.sample_rate = self.get_parameter('sample_rate').value\n        self.chunk_size = self.get_parameter('chunk_size').value\n        self.resets = self.get_parameter('resets').value\n        \n        self.reset_count = 0\n        \n        if self.device_name == '':\n            raise ValueError(\"must set the 'device' parameter to either an input audio device ID/name or the path to a .wav file\")\n        \n        self.get_logger().info(f'device={self.device_name}')\n        self.get_logger().info(f'sample_rate={self.sample_rate}')\n        self.get_logger().info(f'chunk_size={self.chunk_size}')\n        self.get_logger().info(f'resets={self.resets}')\n        \n        # check if this is an audio device or a wav file\n        file_ext = os.path.splitext(self.device_name)[1].lower()\n        \n        if file_ext == '.wav' or file_ext == '.wave':\n            wav = self.device_name\n            mic = ''\n        else:\n            wav = ''\n            mic = self.device_name\n\n        # create audio device\n        self.device = 
AudioInput(wav=wav, mic=mic, sample_rate=self.sample_rate, chunk_size=self.chunk_size)\n        self.device.open()\n\n        # create a timer to check for audio samples\n        self.timer = self.create_timer(self.chunk_size / self.sample_rate * 0.75, self.publish_audio)\n        \n        \n    def publish_audio(self):\n    \n        while True:\n            samples = self.device.next()\n            \n            if samples is not None:\n                break\n                \n            self.get_logger().warning('no audio samples were returned from the audio device')\n            \n            if self.resets < 0 or self.reset_count < self.resets:\n                self.reset_count += 1\n                self.get_logger().warning(f'resetting audio device {self.device_name} (attempt {self.reset_count} of {self.resets})')\n                self.device.reset()\n            else:\n                self.get_logger().error(f'maximum audio device resets has been reached ({self.resets})')\n                return\n                \n        if samples.dtype == np.float32:  # convert to int16 to make the message smaller\n            samples = audio_to_int16(samples)\n\n        if samples.dtype != np.int16:  # the other voice nodes expect int16/float32\n            raise ValueError(f'audio samples are expected to have datatype int16, but they were {samples.dtype}')\n        \n        self.get_logger().debug(f'publishing audio samples {samples.shape} dtype={samples.dtype}') # rms={np.sqrt(np.mean(samples**2))}')\n        \n        # publish message\n        msg = Audio()\n        \n        msg.header.stamp = self.get_clock().now().to_msg()\n        msg.header.frame_id = self.device_name\n\n        msg.info.channels = 1  # AudioInput is set to mono\n        msg.info.sample_rate = self.sample_rate\n        msg.info.sample_format = str(samples.dtype)\n        \n        msg.data = samples.tobytes()\n        \n        self.audio_publisher.publish(msg)\n        \n        \ndef 
main(args=None):\n    rclpy.init(args=args)\n    node = AudioInputNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/jetson_voice_ros/audio_output.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\nimport numpy as np\n\nfrom rclpy.node import Node\n\nfrom jetson_voice import AudioOutput\nfrom jetson_voice_ros.msg import Audio\n\nfrom soundfile import SoundFile\n\n\nclass AudioOutputNode(Node):\n    def __init__(self):\n        super().__init__('audio_output', namespace='voice')\n        \n        # create topics\n        self.audio_subscriber = self.create_subscription(Audio, 'audio_out', self.audio_listener, 10)\n        \n        # get node parameters\n        self.declare_parameter('device', '')          # input audio device ID or name\n        self.declare_parameter('sample_rate', 16000)  # sample rate (in Hz)\n        self.declare_parameter('chunk_size', 4096)    # number of samples per buffer\n        \n        self.device_name = str(self.get_parameter('device').value)\n        self.sample_rate = self.get_parameter('sample_rate').value\n        self.chunk_size = self.get_parameter('chunk_size').value\n        \n        if self.device_name == '':\n            raise ValueError(\"must set the 'device' parameter to either an input audio device ID/name or the path to a .wav file\")\n        \n        self.get_logger().info(f'device={self.device_name}')\n        self.get_logger().info(f'sample_rate={self.sample_rate}')\n        self.get_logger().info(f'chunk_size={self.chunk_size}')\n        \n        # check if this is an audio device or a wav file\n        file_ext = os.path.splitext(self.device_name)[1].lower()\n        \n        if file_ext == '.wav' or file_ext == '.wave':\n            self.wav = SoundFile(self.device_name, mode='w', samplerate=self.sample_rate, channels=1)\n            self.device = None\n        else:\n            self.wav = None\n            self.device = AudioOutput(self.device_name, sample_rate=self.sample_rate, chunk_size=self.chunk_size)\n\n    def audio_listener(self, msg):\n        #self.get_logger().debug('recieved new audio message')\n        
#self.get_logger().debug(f'{msg.header}')\n        #self.get_logger().debug(f'{msg.info}')\n        \n        if msg.info.sample_rate != self.sample_rate:\n            self.get_logger().warning(f\"audio has sample_rate {msg.info.sample_rate}, but audio device is using sample_rate {self.sample_rate}\")\n            \n        samples = np.frombuffer(msg.data, dtype=msg.info.sample_format)\n        \n        self.get_logger().debug(f'recieved audio samples {samples.shape} dtype={samples.dtype}') # rms={np.sqrt(np.mean(samples**2))}')\n        \n        if self.device is not None:\n            self.device.write(samples)\n        else:\n            self.wav.write(samples)\n\n\ndef main(args=None):\n    rclpy.init(args=args)\n    node = AudioOutputNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/jetson_voice_ros/nlp_intent_slot.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\n\nfrom rclpy.node import Node\nfrom std_msgs.msg import String\n\nfrom jetson_voice import IntentSlot as IntentSlotFactory\nfrom jetson_voice_ros.msg import IntentSlot, Slot\n\n\nclass NLPIntentSlotNode(Node):\n    def __init__(self):\n        super().__init__('nlp_intent_slot', namespace='voice')\n        \n        # create topics\n        self.query_subscriber = self.create_subscription(String, 'intent_slot_query', self.query_listener, 10)\n        self.result_publisher = self.create_publisher(IntentSlot, 'intent_slot_results', 10)\n\n        # get node parameters\n        self.declare_parameter('model', 'distilbert_intent')\n        self.model_name = str(self.get_parameter('model').value)\n        self.get_logger().info(f'model = {self.model_name}')\n\n        # load the IntentSlot model\n        self.model = IntentSlotFactory(self.model_name)\n        self.get_logger().info(f\"model '{self.model_name}' ready\")\n        \n    def query_listener(self, msg):\n        text = msg.data.strip()\n        \n        if len(text) == 0:\n            return\n            \n        self.get_logger().info(f\"running NLP Intent/Slot query:  '{text}'\")\n        \n        # run the model\n        results = self.model(text)\n        \n        self.get_logger().info(f\"intent: '{results['intent']}'\")\n        self.get_logger().info(f\"score:  {results['score']}\")\n        \n        for slot in results['slots']:\n            self.get_logger().info(str(slot))\n\n        # create message\n        msg = IntentSlot()\n        \n        msg.query.data = text\n        msg.intent.data = results['intent']\n        msg.score = float(results['score'])\n        \n        slots = []\n        \n        for slot in results['slots']:\n            slot_msg = Slot()\n            \n            slot_msg.slot.data = slot['slot']\n            slot_msg.text.data = slot['text']\n            slot_msg.score = float(slot['score'])\n         
   \n            slots.append(slot_msg)\n            \n        msg.slots = tuple(slots)\n        \n        # publish message\n        self.result_publisher.publish(msg)\n\n\ndef main(args=None):\n    rclpy.init(args=args)\n    node = NLPIntentSlotNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/jetson_voice_ros/nlp_question_answer.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\n\nfrom rclpy.node import Node\nfrom std_msgs.msg import String\n\nfrom jetson_voice import QuestionAnswer as QuestionAnswerFactory\nfrom jetson_voice_ros.msg import QuestionAnswerQuery, QuestionAnswerResult\n\n\nclass NLPQuestionAnswerNode(Node):\n    def __init__(self):\n        super().__init__('nlp_question_answer', namespace='voice')\n        \n        # create topics\n        self.query_subscriber = self.create_subscription(QuestionAnswerQuery, 'question_answer_query', self.query_listener, 10)\n        self.result_publisher = self.create_publisher(QuestionAnswerResult, 'question_answer_results', 10)\n\n        # get node parameters\n        self.declare_parameter('model', 'distilbert_qa_384')\n        self.model_name = str(self.get_parameter('model').value)\n        self.get_logger().info(f'model = {self.model_name}')\n\n        # load the QA model\n        self.model = QuestionAnswerFactory(self.model_name)\n        self.get_logger().info(f\"model '{self.model_name}' ready\")\n        \n    def query_listener(self, msg):\n        question = msg.question.data.strip()\n        context = msg.context.data.strip()\n\n        if len(question) == 0 or len(context) == 0:\n            return\n            \n        self.get_logger().info(f\"running NLP Question/Answer query:\")\n        self.get_logger().info(f\"question:  '{question}'\")\n        self.get_logger().info(f\"context:\")\n        self.get_logger().info(context)\n        \n        # run the model\n        results = self.model((question,context))\n        \n        self.get_logger().info(f\"answer: '{results['answer']}'\")\n        self.get_logger().info(f\"score:  {results['score']}\")\n\n        # create message\n        msg = QuestionAnswerResult()\n        \n        msg.question.data = question\n        msg.answer.data = results['answer']\n        msg.score = results['score']\n        \n        # publish message\n        
self.result_publisher.publish(msg)\n\n\ndef main(args=None):\n    rclpy.init(args=args)\n    node = NLPQuestionAnswerNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/jetson_voice_ros/tts.py",
    "content": "#!/usr/bin/env python3\nimport os\nimport rclpy\nimport numpy as np\n\nfrom rclpy.node import Node\nfrom std_msgs.msg import String\n\nfrom jetson_voice import TTS\nfrom jetson_voice.utils import audio_to_int16\nfrom jetson_voice_ros.msg import Audio\n\n\nclass TTSNode(Node):\n    def __init__(self):\n        super().__init__('tts', namespace='voice')\n        \n        # create topics\n        self.text_subscriber = self.create_subscription(String, 'tts_text', self.text_listener, 10)\n        self.audio_publisher = self.create_publisher(Audio, 'tts_audio', 10)\n\n        # get node parameters\n        self.declare_parameter('model', 'fastpitch_hifigan')\n        self.model_name = str(self.get_parameter('model').value)\n        self.get_logger().info(f'model = {self.model_name}')\n\n        # load the TTS model\n        self.tts = TTS(self.model_name)\n        self.get_logger().info(f\"model '{self.model_name}' ready\")\n        \n    def text_listener(self, msg):\n        text = msg.data.strip()\n        \n        if len(text) == 0:\n            return\n            \n        self.get_logger().info(f\"running TTS on '{text}'\")\n        \n        samples = self.tts(text)\n        samples = audio_to_int16(samples)\n        \n        # publish message\n        msg = Audio()\n        \n        msg.header.stamp = self.get_clock().now().to_msg()\n        msg.header.frame_id = self.model_name\n\n        msg.info.channels = 1\n        msg.info.sample_rate = self.tts.sample_rate\n        msg.info.sample_format = str(samples.dtype)\n        \n        msg.data = samples.tobytes()\n        \n        self.audio_publisher.publish(msg)\n        \n\ndef main(args=None):\n    rclpy.init(args=args)\n    node = TTSNode()\n    rclpy.spin(node)\n    node.destroy_node()\n    rclpy.shutdown()\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "ros/launch/asr.launch.py",
    "content": "#\r\n# Launch file for running speech recognition (ASR) on an audio stream or wav file.\r\n#\r\nimport os\r\n\r\nfrom launch import LaunchDescription\r\nfrom launch.actions import IncludeLaunchDescription, DeclareLaunchArgument\r\nfrom launch.launch_description_sources import PythonLaunchDescriptionSource\r\nfrom launch.substitutions import ThisLaunchFileDir, LaunchConfiguration\r\nfrom launch_ros.actions import Node\r\n\r\n\r\ndef generate_launch_description():\r\n    \r\n    log_level = DeclareLaunchArgument('log_level', default_value='info')\r\n    asr_model = DeclareLaunchArgument('model', default_value='quartznet')\r\n    input_device = DeclareLaunchArgument('input_device', default_value='/jetson-voice/data/audio/dusty.wav')\r\n\r\n    audio_input = Node(package='jetson_voice_ros', node_executable='audio_input.py',\r\n                       parameters=[\r\n                            {\"device\": LaunchConfiguration('input_device')},\r\n                       ],\r\n                       arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                       output='screen', emulate_tty=True)              \r\n     \r\n    asr_node = Node(package='jetson_voice_ros', node_executable='asr.py',\r\n                        parameters=[\r\n                            {\"model\": LaunchConfiguration('model')},\r\n                        ],\r\n                        arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                        output='screen', emulate_tty=True)  \r\n                        \r\n    return LaunchDescription([\r\n        log_level,\r\n        asr_model,\r\n        input_device,\r\n        audio_input,\r\n        asr_node,\r\n    ])"
  },
  {
    "path": "ros/launch/audio_playback.launch.py",
    "content": "#\r\n# Launch file for playback of an audio stream or wav file.\r\n#\r\nimport os\r\n\r\nfrom launch import LaunchDescription\r\nfrom launch.actions import IncludeLaunchDescription, DeclareLaunchArgument\r\nfrom launch.launch_description_sources import PythonLaunchDescriptionSource\r\nfrom launch.substitutions import ThisLaunchFileDir, LaunchConfiguration\r\nfrom launch_ros.actions import Node\r\n\r\n\r\ndef generate_launch_description():\r\n    \r\n    log_level = DeclareLaunchArgument('log_level', default_value='info')\r\n    \r\n    input_device = DeclareLaunchArgument('input_device', default_value='/jetson-voice/data/audio/dusty.wav')\r\n    output_device = DeclareLaunchArgument('output_device', default_value='/jetson-voice/data/audio/output.wav')\r\n    \r\n    audio_input = Node(package='jetson_voice_ros', node_executable='audio_input.py',\r\n                       parameters=[\r\n                            {\"device\": LaunchConfiguration('input_device')},\r\n                       ],\r\n                       arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                       output='screen', emulate_tty=True)              \r\n     \r\n    audio_output = Node(package='jetson_voice_ros', node_executable='audio_output.py',\r\n                        parameters=[\r\n                            {\"device\": LaunchConfiguration('output_device')},\r\n                        ],\r\n                        remappings=[\r\n                            (\"/voice/audio_out\", \"/voice/audio_in\"),\r\n                        ],\r\n                        arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                        output='screen', emulate_tty=True)  \r\n                        \r\n    return LaunchDescription([\r\n        log_level,\r\n        input_device,\r\n        output_device,\r\n        audio_input,\r\n        audio_output,\r\n    ])"
  },
  {
    "path": "ros/launch/tts.launch.py",
    "content": "#\r\n# Launch file for text-to-speech (TTS), sending the synthesized audio to an output device or wav file.\r\n#\r\nimport os\r\n\r\nfrom launch import LaunchDescription\r\nfrom launch.actions import IncludeLaunchDescription, DeclareLaunchArgument\r\nfrom launch.launch_description_sources import PythonLaunchDescriptionSource\r\nfrom launch.substitutions import ThisLaunchFileDir, LaunchConfiguration\r\nfrom launch_ros.actions import Node\r\n\r\n\r\ndef generate_launch_description():\r\n    \r\n    log_level = DeclareLaunchArgument('log_level', default_value='info')\r\n    tts_model = DeclareLaunchArgument('model', default_value='fastpitch_hifigan')\r\n    output_device = DeclareLaunchArgument('output_device', default_value='/jetson-voice/data/audio/tts_test.wav')\r\n\r\n    tts_node = Node(package='jetson_voice_ros', node_executable='tts.py',\r\n                        parameters=[\r\n                            {\"model\": LaunchConfiguration('model')},\r\n                        ],\r\n                        arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                        output='screen', emulate_tty=True)  \r\n   \r\n    audio_output = Node(package='jetson_voice_ros', node_executable='audio_output.py',\r\n                        parameters=[\r\n                            {\"device\": LaunchConfiguration('output_device')},\r\n                            {\"sample_rate\": 22050},\r\n                        ],\r\n                        remappings=[\r\n                            (\"/voice/audio_out\", \"/voice/tts_audio\"),\r\n                        ],\r\n                        arguments=['--ros-args', '--log-level', LaunchConfiguration('log_level')],\r\n                        output='screen', emulate_tty=True)  \r\n                        \r\n    return LaunchDescription([\r\n        log_level,\r\n        tts_model,\r\n        output_device,\r\n        tts_node,\r\n        audio_output,\r\n    ])"
  },
  {
    "path": "ros/msg/Audio.msg",
    "content": "std_msgs/Header header\r\nAudioInfo info\r\nuint8[] data"
  },
  {
    "path": "ros/msg/AudioInfo.msg",
    "content": "# Number of channels\r\nuint8 channels\r\n\r\n# Sampling rate [Hz]\r\nuint32 sample_rate\r\n\r\n# Audio format (e.g. int16, float32)\r\nstring sample_format\r\n\r\n# Audio coding format (e.g. wav, mp3)\r\nstring coding_format"
  },
  {
    "path": "ros/msg/IntentSlot.msg",
    "content": "# the original query text\r\nstd_msgs/String query\r\n\r\n# the classified intent label\r\nstd_msgs/String intent\r\n\r\n# the intent probability between [0,1]\r\nfloat32 score\r\n\r\n# list of slots\r\njetson_voice_ros/Slot[] slots"
  },
  {
    "path": "ros/msg/QuestionAnswerQuery.msg",
    "content": "# the question being asked\r\nstd_msgs/String question\r\n\r\n# the context paragraph\r\nstd_msgs/String context"
  },
  {
    "path": "ros/msg/QuestionAnswerResult.msg",
    "content": "# the question that was asked\r\nstd_msgs/String question\r\n\r\n# the answer to the question\r\nstd_msgs/String answer\r\n\r\n# the confidence of the answer between [0,1]\r\nfloat32 score"
  },
  {
    "path": "ros/msg/Slot.msg",
    "content": "# the slot class label\r\nstd_msgs/String slot\r\n\r\n# the relevant text from the original query\r\nstd_msgs/String text\r\n\r\n# classification probability between [0,1]\r\nfloat32 score"
  },
  {
    "path": "ros/package.xml",
    "content": "<?xml version=\"1.0\"?>\n<?xml-model href=\"http://download.ros.org/schema/package_format3.xsd\" schematypens=\"http://www.w3.org/2001/XMLSchema\"?>\n<package format=\"3\">\n  <name>jetson_voice_ros</name>\n  <version>0.0.0</version>\n  <description>ROS2 nodes for jetson_voice</description>\n  <maintainer email=\"dustinf@nvidia.com\">Dustin Franklin</maintainer>\n  <license>MIT</license>\n\n  <depend>rclpy</depend>\n  <depend>std_msgs</depend>\n  \n  <buildtool_depend>ament_cmake</buildtool_depend>\n  <buildtool_depend>ament_cmake_python</buildtool_depend>\n  \n  <build_depend>rosidl_default_generators</build_depend>\n  <exec_depend>rosidl_default_runtime</exec_depend>\n  <member_of_group>rosidl_interface_packages</member_of_group>\n\n  <test_depend>ament_lint_auto</test_depend>\n  <test_depend>ament_lint_common</test_depend>\n\n  <export>\n    <build_type>ament_cmake</build_type>\n  </export>\n</package>\n"
  },
  {
    "path": "scripts/list_audio_devices.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice import list_audio_devices\n    \nlist_audio_devices()\n   \n    "
  },
  {
    "path": "scripts/list_models.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nfrom jetson_voice import list_models\n    \nlist_models()\n   \n    "
  },
  {
    "path": "scripts/nemo_export_onnx.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport argparse\nimport pprint\nimport json\n\nimport nemo\nimport nemo.collections.asr as nemo_asr\nimport nemo.collections.nlp as nemo_nlp\nimport nemo.collections.tts as nemo_tts\n\nfrom omegaconf import OmegaConf, open_dict\n\n\nmodel_types = {\n    'asr' : nemo_asr.models.ASRModel,\n    'asr_classification' : nemo_asr.models.ASRModel,\n    'qa' : nemo_nlp.models.QAModel,\n    'intent_slot' : nemo_nlp.models.IntentSlotClassificationModel,\n    'text_classification' : nemo_nlp.models.TextClassificationModel,\n    'token_classification' : nemo_nlp.models.TokenClassificationModel\n}\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--type', choices=model_types.keys(), type=str, required=True)\nparser.add_argument('--model', type=str, required=True)   # 'QuartzNet15x5Base-En'\nparser.add_argument('--output', default='', type=str, required=True)\n\nargs = parser.parse_args()\n\nprint('nemo version:', nemo.__version__)\n\n# load model depending on extension/type\nextension = os.path.splitext(args.model)[1].lower()\n\nif extension == '.nemo':\n    model = model_types[args.type].restore_from(args.model)\nelif extension == '.ckpt':\n    model = model_types[args.type].load_from_checkpoint(args.model)\nelse: #elif: len(extension) == 0:\n    model = model_types[args.type].from_pretrained(model_name=args.model)\n#else:\n#    raise ValueError(f'model {args.model} has invalid extension {extension}')\n\n# add type string so we can more easily track this later   \nwith open_dict(model._cfg):\n    model._cfg.type = args.type\n    model._cfg.model_path = os.path.basename(args.output)\n    model._cfg.model_origin = args.model\n    \nprint('')\nprint('###############################################')\nprint('## Model Config')\nprint('###############################################')\npprint.pprint(OmegaConf.to_container(model._cfg))\nprint('')\n\nbase_path = os.path.splitext(args.output)[0]\njson_path 
= base_path + '.json'\nyaml_path = base_path + '.yaml'\n\n#with open(yaml_path, 'w') as yaml_file:\n#  OmegaConf.save(config=model._cfg, f=yaml_file)\n#  print('saved model config to {:s}'.format(yaml_path))\n  \nwith open(json_path, 'w') as json_file:\n  json.dump(OmegaConf.to_container(model._cfg), json_file, indent=3)\n  print('saved model config to {:s}'.format(json_path))\n  \nmodel.export(args.output, verbose=True)\n\nprint('\\nexported {:s} to {:s}'.format(args.model, args.output))\n"
  },
  {
    "path": "scripts/nemo_list_models.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport nemo\r\nimport nemo.collections.asr as nemo_asr\r\nimport nemo.collections.nlp as nemo_nlp\r\nimport nemo.collections.tts as nemo_tts\r\n\r\nprint('nemo version:', nemo.__version__)\r\n\r\nasr_archs = [model for model in dir(nemo_asr.models) if model.endswith(\"Model\")]\r\nnlp_archs = [model for model in dir(nemo_nlp.models) if model.endswith(\"Model\")]\r\ntts_archs = [model for model in dir(nemo_tts.models) if model.endswith(\"Model\")]\r\n\r\nprint('ASR architectures:', asr_archs)  \r\nprint('NLP architectures:', nlp_archs)\r\nprint('TTS architectures:', tts_archs)\r\n\r\nfor asr_arch in asr_archs:\r\n    print('')\r\n    print('#####################################################')\r\n    print('## nemo_asr.models.{:s}'.format(asr_arch))\r\n    print('#####################################################')\r\n    print(getattr(nemo_asr.models, asr_arch).list_available_models())\r\n\r\nfor nlp_arch in nlp_archs:\r\n    print('')\r\n    print('#####################################################')\r\n    print('## nemo_nlp.models.{:s}'.format(nlp_arch))\r\n    print('#####################################################')\r\n    print(getattr(nemo_nlp.models, nlp_arch).list_available_models())\r\n    \r\nprint('')\r\nprint('#####################################################')\r\nprint('## nemo_nlp.models.pretrained_lm_models')\r\nprint('#####################################################')\r\nfor model in nemo_nlp.modules.get_pretrained_lm_models_list():\r\n    print(model)\r\n\r\nfor tts_arch in tts_archs:\r\n    print('')\r\n    print('#####################################################')\r\n    print('## nemo_tts.models.{:s}'.format(tts_arch))\r\n    print('#####################################################')\r\n    print(getattr(nemo_tts.models, tts_arch).list_available_models())"
  },
  {
    "path": "scripts/nemo_train_classifier.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport argparse\r\nimport torch\r\nimport pytorch_lightning as pl\r\n\r\nfrom omegaconf import OmegaConf\r\n\r\nfrom nemo.utils.exp_manager import exp_manager\r\nfrom nemo.collections import nlp as nemo_nlp\r\n\r\n\"\"\"\r\nExample SST2 'Stanford Sentiment Treebank' dataset from:\r\n    https://gluebenchmark.com/tasks\r\n    https://dl.fbaipublicfiles.com/glue/data/SST-2.zip\r\n    \r\nPre-processing commands:\r\n    sed 1d train.tsv > train_nemo_format.tsv\r\n    sed 1d test.tsv > test_nemo_format.tsv\r\n    sed 1d dev.tsv > dev_nemo_format.tsv\r\n\"\"\"\r\n\r\n# parse args\r\nparser = argparse.ArgumentParser()\r\n\r\nparser.add_argument('--dataset', default='datasets/sentiment/SST-2', type=str)\r\nparser.add_argument('--config', default='config/text_classification_config.yaml', type=str)\r\nparser.add_argument('--model', default='distilbert-base-uncased', type=str) # \"bert-base-uncased\"\r\nparser.add_argument('--classes', default=2, type=int)\r\nparser.add_argument('--epochs', default=5, type=int)\r\nparser.add_argument('--samples', default=-1, type=int)\r\nparser.add_argument('--batch-size', default=32, type=int)\r\nparser.add_argument('--learning-rate', '--lr', default=0.00002, type=float)\r\nparser.add_argument('--max-seq-length', default=128, type=int)\r\n\r\nargs = parser.parse_args()\r\nprint(args)\r\n\r\n# load config\r\nconfig = OmegaConf.load(args.config)\r\nprint(f'loaded config from {args.config}')\r\n\r\n# setup config\r\nconfig.model.train_ds.file_path = os.path.join(args.dataset, 'train_nemo_format.tsv')\r\nconfig.model.validation_ds.file_path = os.path.join(args.dataset, 'dev_nemo_format.tsv')\r\nconfig.model.test_ds.file_path = os.path.join(args.dataset, 'test_nemo_format.tsv')\r\n\r\nconfig.model.dataset.num_classes = args.classes\r\nconfig.model.dataset.max_seq_length = args.max_seq_length\r\n\r\nconfig.model.language_model.pretrained_model_name = 
args.model\r\nconfig.model.tokenizer.tokenizer_name = args.model\r\n\r\nconfig.model.train_ds.batch_size = args.batch_size\r\nconfig.model.validation_ds.batch_size = args.batch_size\r\nconfig.model.test_ds.batch_size = args.batch_size\r\n\r\nif args.samples >  0:\r\n    config.model.train_ds.num_samples = args.samples\r\n    config.model.validation_ds.num_samples = args.samples\r\n    config.model.test_ds.num_samples = args.samples\r\n\r\nconfig.model.optim.lr = args.learning_rate\r\n\r\nconfig.trainer.gpus = 1 if torch.cuda.is_available() else 0\r\nconfig.trainer.precision = 16 if torch.cuda.is_available() else 32  # For mixed precision training, use precision=16 and amp_level=O1\r\nconfig.trainer.max_epochs = args.epochs\r\nconfig.trainer.accelerator = None   # Remove distributed training flags\r\n\r\nprint(OmegaConf.to_yaml(config))\r\n\r\n# create trainer + model\r\ntrainer = pl.Trainer(**config.trainer)\r\nmodel   = nemo_nlp.models.TextClassificationModel(config.model, trainer=trainer)\r\nexp_dir = str(exp_manager(trainer, config.get(\"exp_manager\", None)))\r\n\r\nprint('experiment directory:', exp_dir)\r\n\r\n# start the training\r\ntrainer.fit(model)\r\n\r\n# test the model\r\neval_checkpoint_path = trainer.checkpoint_callback.best_model_path\r\neval_model = nemo_nlp.models.TextClassificationModel.load_from_checkpoint(checkpoint_path=eval_checkpoint_path)\r\n\r\nprint('loaded checkpoint for eval:', eval_checkpoint_path)\r\n\r\neval_model.setup_test_data(test_data_config=config.model.validation_ds)\r\ntrainer.test(model=eval_model, ckpt_path=None, verbose=True)\r\n\r\n# example inference\r\nqueries = [\r\n    'by the end of no such thing the audience , like beatrice , has a watchful affection for the monster .', \r\n    'director rob marshall went out gunning to make a great one .', \r\n    'uneasy mishmash of styles and genres .'\r\n]\r\n\r\nresults = eval_model.classifytext(\r\n            queries=queries, \r\n            batch_size=len(queries), \r\n      
      max_seq_length=config.model.dataset.max_seq_length\r\n        )\r\n\r\nprint('The prediction results of some sample queries with the trained model:')\r\n\r\nfor query, result in zip(queries, results):\r\n    print(f'Query : {query}')\r\n    print(f'Predicted label: {result}')\r\n    \r\nprint('\\ndone training:', exp_dir)\r\n\r\n"
  },
  {
    "path": "scripts/nemo_train_intent.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport argparse\nimport torch\nimport pytorch_lightning as pl\n\nfrom omegaconf import OmegaConf\n\nfrom nemo.utils.exp_manager import exp_manager\nfrom nemo.collections import nlp as nemo_nlp\n\n\"\"\"\nExample dataset from:\n    https://github.com/xliuhw/NLU-Evaluation-Data\n    https://github.com/xliuhw/NLU-Evaluation-Data/archive/master.zip\n    \nCommand used to pre-process the data:\n\n    python3 intent_import_datasets.py \\\n        --dataset_name=assistant \\\n        --source_data_dir=datasets/intent/NLU-Evaluation-Data-master \\\n        --target_data_dir=datasets/intent/NLU-Evaluation-Data-master/nemo_format\n\"\"\"\n\n# parse args\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--dataset', default='data/datasets/NLU-Evaluation-Data-master/nemo_format', type=str)\nparser.add_argument('--config', default='data/config/training/intent_slot_classification_config.yaml', type=str)\nparser.add_argument('--exp-dir', default='data/nemo_experiments', type=str) \nparser.add_argument('--model', default='distilbert-base-uncased', type=str) # \"bert-base-uncased\"\nparser.add_argument('--epochs', default=5, type=int)\nparser.add_argument('--samples', default=-1, type=int)\nparser.add_argument('--batch-size', default=32, type=int)\nparser.add_argument('--learning-rate', '--lr', default=0.00002, type=float)\nparser.add_argument('--max-seq-length', default=50, type=int)\n\nargs = parser.parse_args()\nprint(args)\n\n# load config\nconfig = OmegaConf.load(args.config)\nprint(f'loaded config from {args.config}')\n\n# setup config\nconfig.model.data_dir = args.dataset #os.path.join(args.dataset, 'nemo_format')\n\nconfig.model.language_model.max_seq_length = args.max_seq_length\nconfig.model.language_model.pretrained_model_name = args.model\nconfig.model.tokenizer.tokenizer_name = args.model\n\nconfig.model.train_ds.batch_size = args.batch_size\nconfig.model.validation_ds.batch_size = 
args.batch_size\nconfig.model.test_ds.batch_size = args.batch_size\n\nif args.samples >  0:\n    config.model.train_ds.num_samples = args.samples\n    config.model.validation_ds.num_samples = args.samples\n    config.model.test_ds.num_samples = args.samples\n\nconfig.model.optim.lr = args.learning_rate\n\nconfig.trainer.gpus = 1 if torch.cuda.is_available() else 0\nconfig.trainer.precision = 16 if torch.cuda.is_available() else 32  # For mixed precision training, use precision=16 and amp_level=O1\nconfig.trainer.max_epochs = args.epochs\nconfig.trainer.accelerator = None   # Remove distributed training flags\n\nprint(OmegaConf.to_yaml(config))\n\n# create trainer + model\ntrainer = pl.Trainer(**config.trainer)\nmodel   = nemo_nlp.models.IntentSlotClassificationModel(config.model, trainer=trainer)\n\n# set experiment directory\nexp_cfg = config.get('exp_manager', None)\nexp_cfg['exp_dir'] = args.exp_dir\nexp_dir = str(exp_manager(trainer, exp_cfg))\n\nprint('experiment directory:', exp_dir)\n\n# start the training\ntrainer.fit(model)\n\n# test the model\neval_checkpoint_path = trainer.checkpoint_callback.best_model_path\neval_model = nemo_nlp.models.IntentSlotClassificationModel.load_from_checkpoint(checkpoint_path=eval_checkpoint_path)\n\nprint('loaded checkpoint for eval:', eval_checkpoint_path)\n\neval_model.setup_test_data(test_data_config=config.model.test_ds)\ntrainer.test(model=eval_model, ckpt_path=None, verbose=True)\n\n# example inference\nqueries = [\n    'set alarm for seven thirty am',\n    'lower volume by fifty percent',\n    'what is my schedule for tomorrow',\n]\n\npred_intents, pred_slots = eval_model.predict_from_examples(queries, config.model.test_ds)\n\nprint('The prediction results of some sample queries with the trained model:')\n\nfor query, intent, slots in zip(queries, pred_intents, pred_slots):\n    print(f'Query : {query}')\n    print(f'Predicted Intent: {intent}')\n    print(f'Predicted Slots: {slots}')\n    \nprint('\\ndone training:', 
exp_dir)\n\n"
  },
  {
    "path": "scripts/nemo_train_ner.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport argparse\r\nimport torch\r\nimport pytorch_lightning as pl\r\n\r\nfrom omegaconf import OmegaConf\r\n\r\nfrom nemo.utils.exp_manager import exp_manager\r\nfrom nemo.collections import nlp as nemo_nlp\r\n\r\n\"\"\"\r\nExample GMB (Groningen Meaning Bank) dataset from:\r\n    https://dldata-public.s3.us-east-2.amazonaws.com/gmb_v_2.2.0_clean.zip\r\n    \r\nThis version of the dataset is already pre-processed, but other IOB format \r\ndata can be converted using the ner_import_iob.py tool.\r\n\"\"\"\r\n\r\n# parse args\r\nparser = argparse.ArgumentParser()\r\n\r\nparser.add_argument('--dataset', default='datasets/ner/gmb_v_2.2.0_clean', type=str)\r\nparser.add_argument('--config', default='config/token_classification_config.yaml', type=str)\r\nparser.add_argument('--model', default='distilbert-base-uncased', type=str) # \"bert-base-uncased\"\r\nparser.add_argument('--epochs', default=5, type=int)\r\nparser.add_argument('--samples', default=-1, type=int)\r\nparser.add_argument('--batch-size', default=32, type=int)\r\nparser.add_argument('--learning-rate', '--lr', default=0.00005, type=float)\r\nparser.add_argument('--max-seq-length', default=128, type=int)\r\n\r\nargs = parser.parse_args()\r\nprint(args)\r\n\r\n# load config\r\nconfig = OmegaConf.load(args.config)\r\nprint(f'loaded config from {args.config}')\r\n\r\n# setup config\r\nconfig.model.dataset.data_dir = args.dataset\r\nconfig.model.dataset.max_seq_length = args.max_seq_length\r\n\r\nconfig.model.language_model.pretrained_model_name = args.model\r\nconfig.model.tokenizer.tokenizer_name = args.model\r\n\r\nconfig.model.train_ds.batch_size = args.batch_size\r\nconfig.model.validation_ds.batch_size = args.batch_size\r\nconfig.model.test_ds.batch_size = args.batch_size\r\n\r\nif args.samples >  0:\r\n    config.model.train_ds.num_samples = args.samples\r\n    config.model.validation_ds.num_samples = args.samples\r\n    
config.model.test_ds.num_samples = args.samples\r\n\r\nconfig.model.optim.lr = args.learning_rate\r\n\r\nconfig.trainer.gpus = 1 if torch.cuda.is_available() else 0\r\nconfig.trainer.precision = 16 if torch.cuda.is_available() else 32  # For mixed precision training, use precision=16 and amp_level=O1\r\nconfig.trainer.max_epochs = args.epochs\r\nconfig.trainer.accelerator = None   # Remove distributed training flags\r\n\r\nprint(OmegaConf.to_yaml(config))\r\n\r\n# create trainer + model\r\ntrainer = pl.Trainer(**config.trainer)\r\nmodel   = nemo_nlp.models.TokenClassificationModel(config.model, trainer=trainer)\r\nexp_dir = str(exp_manager(trainer, config.get(\"exp_manager\", None)))\r\n\r\nprint('experiment directory:', exp_dir)\r\n\r\n# start the training\r\ntrainer.fit(model)\r\n\r\n# test the model\r\neval_checkpoint_path = trainer.checkpoint_callback.best_model_path\r\neval_model = nemo_nlp.models.TokenClassificationModel.load_from_checkpoint(checkpoint_path=eval_checkpoint_path)\r\n\r\nprint('loaded checkpoint for eval:', eval_checkpoint_path)\r\n\r\neval_model.setup_test_data(test_data_config=config.model.test_ds)\r\ntrainer.test(model=eval_model, ckpt_path=None, verbose=True)\r\n\r\n# example inference\r\neval_model.evaluate_from_file(\r\n    text_file=os.path.join(args.dataset, 'text_dev.txt'),\r\n    labels_file=os.path.join(args.dataset, 'labels_dev.txt'),\r\n    output_dir=exp_dir,\r\n)\r\n    \r\nprint('\\ndone training:', exp_dir)\r\n\r\n"
  },
  {
    "path": "scripts/nemo_train_qa.py",
    "content": "#!/usr/bin/env python3\r\n# coding: utf-8\r\n\r\nimport os\r\nimport argparse\r\nimport torch\r\nimport pytorch_lightning as pl\r\n\r\nfrom omegaconf import OmegaConf\r\n\r\nfrom nemo.utils.exp_manager import exp_manager\r\nfrom nemo.collections import nlp as nemo_nlp\r\n\r\n\r\n# parse args\r\nparser = argparse.ArgumentParser()\r\n\r\nparser.add_argument('--dataset', default='datasets/squad', type=str)\r\nparser.add_argument('--dataset-version', default='v1.1', type=str)\r\nparser.add_argument('--config', default='config/question_answering_squad_config.yaml', type=str)\r\nparser.add_argument('--model', default='distilbert-base-uncased', type=str) # \"bert-base-uncased\"\r\nparser.add_argument('--epochs', default=1, type=int)\r\nparser.add_argument('--samples', default=-1, type=int) # 5000\r\nparser.add_argument('--batch-size', default=12, type=int)\r\nparser.add_argument('--learning-rate', '--lr', default=0.00003, type=float)\r\nparser.add_argument('--max-seq-length', default=384, type=int)\r\nparser.add_argument('--output', default='', type=str) # defaults to ./nemo_experiments\r\n\r\nargs = parser.parse_args()\r\nprint(args)\r\n\r\n# load config\r\nconfig = OmegaConf.load(args.config)\r\nprint(f'loaded config from {args.config}')\r\n\r\n# setup config\r\nconfig.model.train_ds.file = os.path.join(args.dataset, args.dataset_version, f'train-{args.dataset_version}.json')\r\nconfig.model.validation_ds.file = os.path.join(args.dataset, args.dataset_version, f'dev-{args.dataset_version}.json')\r\nconfig.model.test_ds.file = config.model.validation_ds.file\r\n\r\nconfig.model.language_model.pretrained_model_name = args.model\r\nconfig.model.tokenizer.tokenizer_name = args.model\r\nconfig.model.dataset.max_seq_length = args.max_seq_length\r\n\r\nif config.model.dataset.doc_stride >= config.model.dataset.max_seq_length:\r\n    config.model.dataset.doc_stride = int(config.model.dataset.max_seq_length / 2)\r\n    \r\nconfig.model.train_ds.batch_size = 
args.batch_size\r\nconfig.model.validation_ds.batch_size = args.batch_size\r\nconfig.model.test_ds.batch_size = args.batch_size\r\n\r\nif args.samples >  0:\r\n    config.model.train_ds.num_samples = args.samples\r\n    config.model.validation_ds.num_samples = args.samples\r\n    config.model.test_ds.num_samples = args.samples\r\n\r\nconfig.model.optim.lr = args.learning_rate\r\n\r\nconfig.trainer.gpus = 1 if torch.cuda.is_available() else 0\r\nconfig.trainer.precision = 16 if torch.cuda.is_available() else 32  # For mixed precision training, use precision=16 and amp_level=O1\r\nconfig.trainer.max_epochs = args.epochs\r\nconfig.trainer.accelerator = None   # Remove distributed training flags\r\n\r\nif args.output != '':\r\n    config.exp_manager.exp_dir = args.output\r\n\r\nprint(OmegaConf.to_yaml(config))\r\n\r\n\r\n# create trainer + model\r\ntrainer = pl.Trainer(**config.trainer)\r\nmodel   = nemo_nlp.models.QAModel(cfg=config.model, trainer=trainer)\r\nexp_dir = str(exp_manager(trainer, config.get(\"exp_manager\", None)))\r\n\r\nprint('experiment directory:', exp_dir)\r\n\r\n# start the training\r\ntrainer.fit(model)\r\n\r\n# test the model\r\nmodel.setup_test_data(test_data_config=config.model.test_ds)\r\ntrainer.test(model)\r\n\r\n# example inference\r\nall_preds, all_nbests = model.inference(file=config.model.test_ds.file, \r\n                                        output_nbest_file=os.path.join(exp_dir, 'output_nbest.json'),\r\n                                        output_prediction_file=os.path.join(exp_dir, 'output_prediction.json'),\r\n                                        batch_size=args.batch_size, \r\n                                        num_samples=10)\r\n\r\nfor _, item in all_preds.items():\r\n    print(f\"question: {item[0]} answer: {item[1]}\")\r\n    \r\nprint('\\ndone training:', exp_dir)\r\n\r\n"
  },
  {
    "path": "scripts/os_version.sh",
    "content": "#!/usr/bin/env bash\n\nARCH=$(uname -i)\necho \"ARCH:  $ARCH\"\n\nif [ $ARCH = \"aarch64\" ]; then\n\tL4T_VERSION_STRING=$(head -n 1 /etc/nv_tegra_release)\n\n\tif [ -z \"$L4T_VERSION_STRING\" ]; then\n\t\techo \"reading L4T version from \\\"dpkg-query --show nvidia-l4t-core\\\"\"\n\n\t\tL4T_VERSION_STRING=$(dpkg-query --showformat='${Version}' --show nvidia-l4t-core)\n\t\tL4T_VERSION_ARRAY=(${L4T_VERSION_STRING//./ })\t\n\n\t\t#echo ${L4T_VERSION_ARRAY[@]}\n\t\t#echo ${#L4T_VERSION_ARRAY[@]}\n\n\t\tL4T_RELEASE=${L4T_VERSION_ARRAY[0]}\n\t\tL4T_REVISION=${L4T_VERSION_ARRAY[1]}\n\telse\n\t\techo \"reading L4T version from /etc/nv_tegra_release\"\n\n\t\tL4T_RELEASE=$(echo $L4T_VERSION_STRING | cut -f 2 -d ' ' | grep -Po '(?<=R)[^;]+')\n\t\tL4T_REVISION=$(echo $L4T_VERSION_STRING | cut -f 2 -d ',' | grep -Po '(?<=REVISION: )[^;]+')\n\tfi\n\n\tL4T_REVISION_MAJOR=${L4T_REVISION:0:1}\n\tL4T_REVISION_MINOR=${L4T_REVISION:2:1}\n\n\tL4T_VERSION=\"$L4T_RELEASE.$L4T_REVISION\"\n\n\techo \"L4T BSP Version:  L4T R$L4T_VERSION\"\nfi\n\n"
  },
  {
    "path": "scripts/record_mic.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport sys\nimport signal\nimport argparse\n\nfrom jetson_voice import AudioInput, list_audio_devices\nfrom soundfile import SoundFile\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--mic', default=None, type=str, help='device name or number of input microphone')\nparser.add_argument('--output', default=None, type=str, help='path to output wav/ogg/flac file')\nparser.add_argument('--sample-rate', default=16000, type=int, help='sample rate (in Hz)')\nparser.add_argument('--list-devices', action='store_true', help='list audio input devices')\n\nargs = parser.parse_args()\nprint(args)\n\n# list audio devices\nif args.list_devices:\n    list_audio_devices()\n    sys.exit()\n    \n# --mic and --output are only optional when --list-devices is used\nif args.mic is None or args.output is None:\n    parser.error('the following arguments are required: --mic, --output')\n\n# setup exit signal handler        \nrecord = True\n\ndef signal_handler(sig, frame):\n    global record\n    record = False\n    print('Ctrl+C received, exiting...')\n    \nsignal.signal(signal.SIGINT, signal_handler)\n\n# create the output wav\noutput_wav = SoundFile(args.output, mode='w', samplerate=args.sample_rate, channels=1)\n\n# create the audio device\ninput_mic = AudioInput(mic=args.mic, sample_rate=args.sample_rate, chunk_size=4096)\n        \n# loop until user exits\nsample_count = 0\n\nwhile record:\n    samples = input_mic.next()\n    output_wav.write(samples)\n    sample_count += len(samples)\n\noutput_wav.close()\nprint(f\"saved {sample_count / args.sample_rate:.2f} seconds of audio to '{args.output}'\")\n"
  },
  {
    "path": "scripts/start_jupyter.sh",
    "content": "#!/usr/bin/env bash\n\njupyter lab --ip 0.0.0.0 --port 8888 --allow-root &> /var/log/jupyter.log\n\necho \"allow 10 sec for JupyterLab to start @ http://$(hostname -I | cut -d' ' -f1):8888 (password nvidia)\"\necho \"JupyterLab logging location:  /var/log/jupyter.log  (inside the container)\"\n\n"
  },
  {
    "path": "tests/run_tests.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport sys\nimport json\nimport logging\nimport argparse\nimport datetime\nimport subprocess\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--log-dir', default='', type=str, help='directory to save log files under')\nparser.add_argument('--tests', default='data/tests/tests.json', type=str, help='path to config file of tests')\nparser.add_argument('--model', default='', type=str, help='if specified, only run tests that use this model')\nparser.add_argument('--module', default='', type=str, help='if specified, only run tests that use this module')\nparser.add_argument('--config', default='', type=str, help='if specified, only run tests that use this test config')\nparser.add_argument('--generate', action='store_true', help='generate the expected outputs')\n\nargs = parser.parse_args()\n\nif args.log_dir == '':\n    args.log_dir = os.path.join('data/tests/logs', datetime.datetime.now().strftime(\"%Y%m%d_%H%M\"))\n    \nif not os.path.exists(args.log_dir):\n    os.makedirs(args.log_dir)\n\nprint(args)\n\n# wrapper for launching test processes\ndef run_test(module, model, config, args=None, log_dir=None):\n    config = os.path.join('data/tests', config)\n    cmd = f\"python3 tests/{module} --model {model} --config {config}\"\n    \n    if args:\n        cmd += ' ' + args\n       \n    print(\"\\nrunning test:\\n\\t$\", cmd, \"\\n\")  \n\n    if log_dir:\n        tee = f\"tee {os.path.join(log_dir, os.path.splitext(os.path.basename(module))[0])}_{model}.txt\"\n        cmd = f\"mkfifo pipe; {tee} < pipe & {cmd} > pipe; code=$?; rm pipe; exit $code\" # https://stackoverflow.com/a/1221844\n\n    results = subprocess.run(cmd, shell=True)\n    \n    if results.returncode == 0:\n        status = 'PASSED'\n    elif results.returncode == 127:\n        status = 'GENERATED'\n    else:\n        status = 'FAILED'\n        \n    print(f\"\\n{status} TEST {module} ({model}) - return code 
{results.returncode}\\n\")\n    return status\n    \n# load the config containing all the tests\nwith open(args.tests) as config_file:\n    test_config = json.load(config_file)\n\n# filter the tests if requested\ndef filter_test(test):\n    if args.model != '' and args.model != test['model']:\n        return False\n        \n    if args.module != '' and args.module != test['module']:\n        return False\n        \n    if args.config != '' and args.config != test['config']:\n        return False\n        \n    return True\n        \ntest_config = [test for test in test_config if filter_test(test)]\n\n# run the tests\nfor test in test_config:\n    test_args = test.get('args', '')\n    \n    if args.generate:\n        test_args += ' --generate'\n        \n    status = run_test(test['module'], test['model'], test['config'], test_args, args.log_dir)\n    \n    # if the test needed to generate the expected outputs, run it again\n    if status == 'GENERATED':\n        print('generated expected outputs, running test again...')\n        status = run_test(test['module'], test['model'], test['config'], test.get('args'), args.log_dir)\n     \n    test['status'] = status\n\n# test summary\npassed = 0\n\nprint('')\nprint('----------------------------------------------------')\nprint(' TEST SUMMARY')\nprint('----------------------------------------------------')\n\nfor test in test_config:\n    test_str = f\"{test['module']} ({test['model']})\"\n    print(f\"{test_str:<40} {test['status']}\")\n    \n    if test['status'] == 'PASSED':\n        passed += 1\n        \nprint(f\"\\npassed {passed} of {len(test_config)} tests\")\nprint(f\"saved logs to {args.log_dir}\")\n"
  },
  {
    "path": "tests/test_asr.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport sys\nimport json\nimport nltk\nimport logging\n\nfrom jetson_voice import ASR, AudioInput, ConfigArgParser\n\n\nparser = ConfigArgParser()\n\nparser.add_argument('--model', default='quartznet', type=str, help='path to model, service name, or json config file')\nparser.add_argument('--config', type=str, required=True, help='path to test config file')\nparser.add_argument('--threshold', type=int, default=0, help='threshold for comparing actual vs expected outputs')\nparser.add_argument('--generate', action='store_true', help='generate the expected outputs')\n\nargs = parser.parse_args()\nprint(args)\n\n\nprint('')\nprint('----------------------------------------------------')\nprint(' RUNNING TEST (ASR)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint('')\n\n# load test config\nwith open(args.config) as config_file:\n    test_config = json.load(config_file)\n\n# load the model\nasr = ASR(args.model)\n\n# list of (passed, num_outputs) tuples\ntest_results = []\n\n# run tests\nfor test in test_config:\n    stream = AudioInput(wav=test['wav'], \n                         sample_rate=asr.sample_rate, \n                         chunk_size=asr.chunk_size)\n\n    outputs = []\n    \n    for samples in stream:\n        output = asr(samples)\n        \n        if asr.classification:\n            print(f\"class '{output[0]}' ({output[1]:.3f})\")\n            outputs.append(output[0])\n        else:\n            for transcript in output:\n                print(transcript['text'])\n                \n                if transcript['end']:\n                    print('')\n                    outputs.append(transcript['text'])\n\n    if not asr.classification:\n        if not transcript['end']: # pick up the last transcript\n            outputs.append(transcript['text'])\n            \n    if 'outputs' not in test:\n 
       test['outputs'] = {}\n    \n    if args.model not in test['outputs']:\n        args.generate = True\n        \n    if args.generate:\n        test['outputs'][args.model] = outputs\n    else:\n        expected_outputs = test['outputs'][args.model]\n        \n        if len(outputs) != len(expected_outputs):\n            logging.error(f\"failed test '{test['wav']}' - got {len(outputs)} outputs (expected {len(expected_outputs)})\")\n            test_results.append((0, len(expected_outputs)))\n            continue\n        \n        passed = 0\n        \n        for i in range(len(expected_outputs)):\n            similarity = nltk.edit_distance(expected_outputs[i], outputs[i])\n            \n            if similarity > args.threshold:\n                logging.error(f\"failed test '{test['wav']}' - similarity {similarity} exceeded threshold of {args.threshold}\")\n                logging.error(f\"  expected:  '{expected_outputs[i]}'\")\n                logging.error(f\"  actual:    '{outputs[i]}'\")\n            else:\n                passed += 1\n                \n        test_results.append((passed, len(expected_outputs)))\n\nif args.generate:\n    print('')\n    logging.info(f\"generated expected outputs, saving to '{args.config}'\")\n    \n    with open(args.config, 'w') as config_file:\n        json.dump(test_config, config_file, indent=3)\n        \n    sys.exit(127)\n\n# test summary\npassed_tests = 0\npassed_outputs = 0\ntotal_outputs = 0\n\nfor passed, num_outputs in test_results:\n    if passed == num_outputs:\n        passed_tests += 1\n        \n    passed_outputs += passed\n    total_outputs += num_outputs\n\nprint('')\nprint('----------------------------------------------------')\nprint(' TEST RESULTS (ASR)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint(f'   passed: {passed_tests} / {len(test_config)} audio files')\nprint(f'           {passed_outputs} / {total_outputs} outputs')\nprint('')\n\nif passed_tests != len(test_config):\n    logging.error(f\"failed test '{args.config}' with model '{args.model}'\")\n    sys.exit(1)\n"
  },
  {
    "path": "tests/test_nlp.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport sys\nimport json\nimport nltk\nimport pprint\nimport logging\n\nfrom jetson_voice import NLP, ConfigArgParser\n\n\nparser = ConfigArgParser()\n\nparser.add_argument('--model', default='distilbert_qa_128', type=str, help='path to model, service name, or json config file')\nparser.add_argument('--config', type=str, required=True, help='path to test config file')\nparser.add_argument('--threshold', type=int, default=0, help='threshold for comparing actual vs expected outputs')\nparser.add_argument('--generate', action='store_true', help='generate the expected outputs')\n\nargs = parser.parse_args()\nprint(args)\n\nprint('')\nprint('----------------------------------------------------')\nprint(f' RUNNING TEST (NLP)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint('')\n\n# load test config\nwith open(args.config) as config_file:\n    test_config = json.load(config_file)\n\n# load the model\nmodel = NLP(args.model)\ntype = model.config.type\n\n\"\"\"\nif args.type == 'intent_slot':\n    model = IntentSlot(args.model)\nelif args.type == 'qa':\n    model = QuestionAnswer(args.model)\nelif args.type == 'text_classification':\n    model = TextClassification(args.model)\nelif args.type == 'token_classification':\n    model = TokenClassification(args.model)\n\"\"\"\n   \n# list of (passed, num_outputs) tuples\ntest_results = []\n\n# run tests\nfor test in test_config:\n    outputs = []\n    \n    if type == 'intent_slot':\n        for query in test['queries']:\n            results = model(query)\n            \n            print('')\n            print('query:', query, '\\n')\n            pprint.pprint(results)\n            print('')\n            \n            result_str = results['intent']\n            \n            for slot in results['slots']:\n                result_str += f\" 
{slot['slot']}={slot['text']}\"\n                \n            outputs.append(result_str)\n            \n    elif type == 'qa':\n        for question in test['questions']:\n            query = {\n                'question': question,\n                'context': test['context']\n            }\n            \n            answer = model(query, top_k=1)\n            \n            print('\\n')\n            print('context:', query['context'])\n            print('')\n            print('question:', query['question'])\n            print('')\n            print('answer:', answer['answer'])\n            print('score: ', answer['score'])\n            \n            outputs.append(answer['answer'])\n    \n    elif type == 'text_classification':\n        for query in test['queries']:\n            results = model(query)\n            \n            print('')\n            print('query:', query, '\\n')\n            pprint.pprint(results)\n            print('')\n            \n            outputs.append(results['label'])\n    \n    elif type == 'token_classification':\n        for query in test['queries']:\n            results = model(query)\n            result_str = model.tag_string(query, results)\n            \n            print('')\n            print('query:', query, '\\n')\n            print(model.tag_string(query, results, scores=True))\n            print('')\n            \n            outputs.append(result_str)\n            \n    if 'outputs' not in test:\n        test['outputs'] = {}\n    \n    if args.model not in test['outputs']:\n        args.generate = True\n        \n    if args.generate:\n        test['outputs'][args.model] = outputs\n    else:\n        expected_outputs = test['outputs'][args.model]\n        \n        if len(outputs) != len(expected_outputs):\n            logging.error(f\"failed test - got {len(outputs)} outputs (expected {len(expected_outputs)})\")\n            test_results.append((0, len(expected_outputs)))\n            continue\n        \n        passed = 0\n        \n        for i in range(len(expected_outputs)):\n            similarity = nltk.edit_distance(expected_outputs[i], outputs[i])\n            \n            if similarity > args.threshold:\n                logging.error(f\"failed test - similarity {similarity} exceeded threshold of {args.threshold}\")\n                logging.error(f\"  expected:  '{expected_outputs[i]}'\")\n                logging.error(f\"  actual:    '{outputs[i]}'\")\n            else:\n                passed += 1\n                \n        test_results.append((passed, len(expected_outputs)))\n\nif args.generate:\n    print('')\n    logging.info(f\"generated expected outputs, saving to '{args.config}'\")\n    \n    with open(args.config, 'w') as config_file:\n        json.dump(test_config, config_file, indent=3)\n        \n    sys.exit(127)\n\n# test summary\npassed_tests = 0\npassed_outputs = 0\ntotal_outputs = 0\n\nfor passed, num_outputs in test_results:\n    if passed == num_outputs:\n        passed_tests += 1\n        \n    passed_outputs += passed\n    total_outputs += num_outputs\n\nprint('')\nprint('----------------------------------------------------')\nprint(f' TEST RESULTS (NLP)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint(f'   type:   {type}')\nprint(f'   passed: {passed_tests} / {len(test_config)} tests')\nprint(f'           {passed_outputs} / {total_outputs} queries')\nprint('')\n\nif passed_tests != len(test_config):\n    logging.error(f\"failed test '{args.config}' with model '{args.model}'\")\n    sys.exit(1)\n"
  },
  {
    "path": "tests/test_tts.py",
    "content": "#!/usr/bin/env python3\n# coding: utf-8\n\nimport os\nimport sys\nimport json\nimport librosa\nimport logging\nimport datetime\n\nfrom jetson_voice import TTS, ConfigArgParser\nfrom soundfile import SoundFile\n\nparser = ConfigArgParser()\n\nparser.add_argument('--model', default='fastpitch_hifigan', type=str, help='path to model, service name, or json config file')\nparser.add_argument('--config', type=str, required=True, help='path to test config file')\nparser.add_argument('--rms-threshold', type=float, default=0.005, help='threshold for comparing actual vs expected RMS')\nparser.add_argument('--length-threshold', type=float, default=0.1, help='threshold for comparing actual vs expected audio length (in seconds)')\nparser.add_argument('--generate', action='store_true', help='generate the expected outputs')\nparser.add_argument(\"--output-dir\", default='', help='output directory to save generated audio')\n\nargs = parser.parse_args()\n\nif args.output_dir == '':\n    args.output_dir = os.path.join('data/tests/tts', args.model, datetime.datetime.now().strftime(\"%Y%m%d_%H%M\"))\n    \nif not os.path.exists(args.output_dir):\n    os.makedirs(args.output_dir)\n    \nprint(args)\n\nprint('')\nprint('----------------------------------------------------')\nprint(' RUNNING TEST (TTS)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint('')\n\n# load test config\nwith open(args.config) as config_file:\n    test_config = json.load(config_file)\n\n# load the model\ntts = TTS(args.model)\n\n# list of (passed, num_outputs) tuples\npassed = 0\n\n# run tests\nfor idx, test in enumerate(test_config):\n    audio = tts(test['text'])\n    \n    wav_path = os.path.join(args.output_dir, f\"{idx}.wav\")\n    wav = SoundFile(wav_path, mode='w', samplerate=tts.sample_rate, channels=1)\n    wav.write(audio)\n    wav.close()\n    \n    actual_length = len(audio) / tts.sample_rate\n  
  actual_rms = float(librosa.feature.rms(y=audio, frame_length=len(audio), center=False)[0][0])\n    \n    print(f\"'{test['text']}'\")\n    print(f\"audio length = {actual_length}s, RMS = {actual_rms}\")\n    print(f\"saved audio to '{wav_path}'\\n\")\n    \n    if 'outputs' not in test:\n        test['outputs'] = {}\n    \n    if args.model not in test['outputs']:\n        args.generate = True\n        \n    if args.generate:\n        test['outputs'][args.model] = (actual_length, actual_rms)\n    else:\n        expected_length, expected_rms = test['outputs'][args.model]\n        \n        length_diff = abs(expected_length - actual_length)\n        rms_diff = abs(expected_rms - actual_rms)\n        \n        if length_diff > args.length_threshold:\n            logging.error(f\"failed test - length difference of {length_diff}s exceeded threshold of {args.length_threshold} (actual={actual_length}s, expected={expected_length}s)\")\n            logging.error(f\"              '{test['text']}'\")\n            continue\n            \n        if rms_diff > args.rms_threshold:\n            logging.error(f\"failed test - RMS difference of {rms_diff} exceeded threshold of {args.rms_threshold} (actual={actual_rms}, expected={expected_rms})\")\n            logging.error(f\"              '{test['text']}'\")\n            continue\n        \n        passed += 1\n\nif args.generate:\n    print('')\n    logging.info(f\"generated expected outputs, saving to '{args.config}'\")\n    \n    with open(args.config, 'w') as config_file:\n        json.dump(test_config, config_file, indent=3)\n        \n    sys.exit(127)\n\n# test summary\nprint('')\nprint('----------------------------------------------------')\nprint(' TEST RESULTS (TTS)')\nprint('----------------------------------------------------')\nprint(f'   model:  {args.model}')\nprint(f'   config: {args.config}')\nprint(f'   passed: {passed} / {len(test_config)}')\nprint('')\n\nif passed != len(test_config):\n    
logging.error(f\"failed test '{args.config}' with model '{args.model}'\")\n    sys.exit(1)\n"
  }
]